import torch
from typing import Optional

from i18n import _i18n

from .generators.hifigan_mrf import HiFiGANMRFGenerator
from .generators.hifigan_nsf import HiFiGANNSFGenerator
from .generators.hifigan import HiFiGANGenerator
from .generators.refinegan import RefineGANGenerator
from .commons import slice_segments, rand_slice_segments
from .residuals import ResidualCouplingBlock
from .encoders import TextEncoder, PosteriorEncoder


class Synthesizer(torch.nn.Module):
    """Prior encoder + posterior encoder + coupling flow + vocoder decoder.

    Sub-modules:
        enc_p: ``TextEncoder`` producing prior statistics (m_p, logs_p) from
            phone features (and coarse pitch when ``use_f0``).
        enc_q: ``PosteriorEncoder`` mapping a spectrogram ``y`` to latent ``z``.
        flow:  ``ResidualCouplingBlock``; run forward in training and with
            ``reverse=True`` at inference.
        dec:   vocoder chosen by ``vocoder``/``use_f0`` (may be ``None`` for
            combinations that are unsupported without pitch).
        emb_g: speaker embedding table used as the global conditioning ``g``.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: list,
        resblock_dilation_sizes: list,
        upsample_rates: list,
        upsample_initial_channel: int,
        upsample_kernel_sizes: list,
        spk_embed_dim: int,
        gin_channels: int,
        sr: int,
        use_f0: bool,
        text_enc_hidden_dim: int = 768,
        vocoder: str = "HiFi-GAN",
        randomized: bool = True,
        checkpointing: bool = False,
        **kwargs,
    ):
        """Build all sub-modules.

        Args:
            spec_channels: input channels of the posterior encoder.
            segment_size: slice length used by ``forward`` when ``randomized``.
            inter_channels: latent channel count shared by encoders/flow/decoder.
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size,
                p_dropout, text_enc_hidden_dim: forwarded to ``TextEncoder``.
            resblock: accepted for config compatibility; not used here.
            resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
                upsample_initial_channel, upsample_kernel_sizes: vocoder config.
            spk_embed_dim: number of speaker embeddings.
            gin_channels: width of the speaker conditioning vector.
            sr: sample rate passed to pitch-aware vocoders.
            use_f0: whether the model is pitch-guided.
            vocoder: "MRF HiFi-GAN", "RefineGAN", or anything else for
                HiFi-GAN (NSF variant when ``use_f0``).
            randomized: if True, ``forward`` decodes a random latent slice.
            checkpointing: forwarded to the pitch-aware vocoders.
            **kwargs: extra config keys, ignored.
        """
        super().__init__()
        self.segment_size = segment_size
        self.use_f0 = use_f0
        self.randomized = randomized

        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            text_enc_hidden_dim,
            f0=use_f0,
        )
        print(_i18n("using_vocoder") + ": " + vocoder)

        if use_f0:
            if vocoder == "MRF HiFi-GAN":
                self.dec = HiFiGANMRFGenerator(
                    in_channel=inter_channels,
                    upsample_initial_channel=upsample_initial_channel,
                    upsample_rates=upsample_rates,
                    upsample_kernel_sizes=upsample_kernel_sizes,
                    resblock_kernel_sizes=resblock_kernel_sizes,
                    resblock_dilations=resblock_dilation_sizes,
                    gin_channels=gin_channels,
                    sample_rate=sr,
                    harmonic_num=8,
                    checkpointing=checkpointing,
                )
            elif vocoder == "RefineGAN":
                self.dec = RefineGANGenerator(
                    sample_rate=sr,
                    # Mirror of the upsampling schedule.
                    downsample_rates=upsample_rates[::-1],
                    upsample_rates=upsample_rates,
                    start_channels=16,
                    num_mels=inter_channels,
                    checkpointing=checkpointing,
                )
            else:
                self.dec = HiFiGANNSFGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                    sr=sr,
                    checkpointing=checkpointing,
                )
        else:
            # These vocoders need pitch; leave dec unset (None) and warn.
            if vocoder == "MRF HiFi-GAN":
                print("MRF HiFi-GAN does not support training without pitch guidance.")
                self.dec = None
            elif vocoder == "RefineGAN":
                print("RefineGAN does not support training without pitch guidance.")
                self.dec = None
            else:
                self.dec = HiFiGANGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                )

        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels,
            hidden_channels,
            5,
            1,
            3,
            gin_channels=gin_channels,
        )
        self.emb_g = torch.nn.Embedding(spk_embed_dim, gin_channels)

    def _remove_weight_norm_from(self, module):
        """Strip hook-based weight normalization from ``module`` and all its submodules.

        ``None`` is accepted and ignored, because ``self.dec`` is ``None`` for
        vocoders that are unsupported without pitch guidance.
        """
        if module is None:
            return
        # Weight norm lives on the leaf layers, not on the container passed in,
        # so every submodule has to be inspected.
        for submodule in module.modules():
            # Snapshot the hooks: remove_weight_norm deletes entries from
            # _forward_pre_hooks, and mutating a dict while iterating it raises.
            for hook in list(submodule._forward_pre_hooks.values()):
                if type(hook).__name__ == "WeightNorm":
                    # Use the hook's own parameter name instead of assuming "weight".
                    torch.nn.utils.remove_weight_norm(submodule, hook.name)

    def remove_weight_norm(self):
        """Remove weight normalization from the decoder, flow and posterior encoder."""
        for module in [self.dec, self.flow, self.enc_q]:
            self._remove_weight_norm_from(module)

    def __prepare_scriptable__(self):
        """Hook invoked by ``torch.jit.script``: drop weight-norm hooks first."""
        self.remove_weight_norm()
        return self

    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        y_lengths: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
    ):
        """Training pass.

        Args:
            phone, phone_lengths: inputs to the prior encoder.
            pitch: coarse pitch for the prior encoder (pitch-guided models).
            pitchf: pitch passed to the decoder when ``use_f0``; sliced along
                dim 1 (via ``slice_segments(..., 2)``) when ``randomized``.
            y, y_lengths: target spectrogram and lengths; if ``y`` is None only
                the prior is computed.
            ds: speaker ids for ``emb_g``.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``.
            ``ids_slice`` is None when not ``randomized``; when ``y`` is None,
            ``o``, ``ids_slice``, ``y_mask`` and the posterior terms are None.
        """
        # Speaker conditioning with a trailing axis appended.
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        if y is not None:
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
            z_p = self.flow(z, y_mask, g=g)
            if self.randomized:
                # Decode only a random window of the latent to bound memory.
                z_slice, ids_slice = rand_slice_segments(
                    z, y_lengths, self.segment_size
                )
                if self.use_f0:
                    pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
                    o = self.dec(z_slice, pitchf, g=g)
                else:
                    o = self.dec(z_slice, g=g)
                return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
            else:
                if self.use_f0:
                    o = self.dec(z, pitchf, g=g)
                else:
                    o = self.dec(z, g=g)
                return o, None, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
        else:
            return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        nsff0: Optional[torch.Tensor] = None,
        sid: torch.Tensor = None,
        rate: Optional[torch.Tensor] = None,
    ):
        """Inference pass: sample the prior, invert the flow, and decode.

        Args:
            phone, phone_lengths, pitch: inputs to the prior encoder.
            nsff0: pitch passed to the decoder when ``use_f0``.
            sid: speaker ids for ``emb_g``.
            rate: optional scalar tensor; when given, only the last
                ``T - int(T * (1 - rate))`` frames (dim 2) are decoded.

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Sample from the prior with the noise scaled by 0.66666.
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate is not None:
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p, x_mask = z_p[:, :, head:], x_mask[:, :, head:]
            if self.use_f0 and nsff0 is not None:
                nsff0 = nsff0[:, head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = (
            self.dec(z * x_mask, nsff0, g=g)
            if self.use_f0
            else self.dec(z * x_mask, g=g)
        )
        return o, x_mask, (z, z_p, m_p, logs_p)