| import torch |
| from typing import Optional |
| from i18n import _i18n |
| from .generators.hifigan_mrf import HiFiGANMRFGenerator |
| from .generators.hifigan_nsf import HiFiGANNSFGenerator |
| from .generators.hifigan import HiFiGANGenerator |
| from .generators.refinegan import RefineGANGenerator |
| from .commons import slice_segments, rand_slice_segments |
| from .residuals import ResidualCouplingBlock |
| from .encoders import TextEncoder, PosteriorEncoder |
|
|
|
|
class Synthesizer(torch.nn.Module):
    """Synthesizer built from a text encoder, posterior encoder, flow and decoder.

    The module owns five sub-components:

    * ``enc_p`` -- a ``TextEncoder`` fed with ``phone``/``pitch``.
    * ``enc_q`` -- a ``PosteriorEncoder`` fed with ``y`` (training only).
    * ``flow``  -- a ``ResidualCouplingBlock`` applied between the two latents.
    * ``emb_g`` -- an embedding table indexed by the speaker id.
    * ``dec``   -- the waveform generator, chosen from ``vocoder`` and
      ``use_f0`` (may be ``None`` for unsupported combinations, see below).

    Args:
        spec_channels: First positional argument of ``PosteriorEncoder``.
        segment_size: Segment length used when slicing latents during
            randomized training.
        inter_channels: Latent channel count shared by the encoders, the flow
            and the decoder input.
        hidden_channels: Passed to both encoders and the flow.
        filter_channels: Passed to ``TextEncoder``.
        n_heads: Passed to ``TextEncoder``.
        n_layers: Passed to ``TextEncoder``.
        kernel_size: Passed to ``TextEncoder``.
        p_dropout: Passed to ``TextEncoder``.
        resblock: Accepted for signature compatibility; not used here.
        resblock_kernel_sizes: Passed to the HiFi-GAN style generators.
        resblock_dilation_sizes: Passed to the HiFi-GAN style generators.
        upsample_rates: Passed to the generator; RefineGAN additionally
            receives the reversed list as ``downsample_rates``.
        upsample_initial_channel: Passed to the HiFi-GAN style generators.
        upsample_kernel_sizes: Passed to the HiFi-GAN style generators.
        spk_embed_dim: Number of rows of the speaker embedding table.
        gin_channels: Width of the speaker embedding / conditioning input.
        sr: Sample rate handed to the pitch-guided generators.
        use_f0: Whether pitch guidance is used; selects the generator family.
        text_enc_hidden_dim: Passed to ``TextEncoder`` (default 768).
        vocoder: ``"MRF HiFi-GAN"``, ``"RefineGAN"``, or anything else for the
            default HiFi-GAN variants.
        randomized: If true, ``forward`` decodes a random slice of the latent
            instead of the full sequence.
        checkpointing: Forwarded to the pitch-guided generators.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: list,
        resblock_dilation_sizes: list,
        upsample_rates: list,
        upsample_initial_channel: int,
        upsample_kernel_sizes: list,
        spk_embed_dim: int,
        gin_channels: int,
        sr: int,
        use_f0: bool,
        text_enc_hidden_dim: int = 768,
        vocoder: str = "HiFi-GAN",
        randomized: bool = True,
        checkpointing: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.segment_size = segment_size
        self.use_f0 = use_f0
        self.randomized = randomized

        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            text_enc_hidden_dim,
            f0=use_f0,
        )
        print(_i18n("using_vocoder") + ": " + vocoder)
        if use_f0:
            # Pitch-guided generators: all three receive the sample rate and
            # the checkpointing flag.
            if vocoder == "MRF HiFi-GAN":
                self.dec = HiFiGANMRFGenerator(
                    in_channel=inter_channels,
                    upsample_initial_channel=upsample_initial_channel,
                    upsample_rates=upsample_rates,
                    upsample_kernel_sizes=upsample_kernel_sizes,
                    resblock_kernel_sizes=resblock_kernel_sizes,
                    resblock_dilations=resblock_dilation_sizes,
                    gin_channels=gin_channels,
                    sample_rate=sr,
                    harmonic_num=8,
                    checkpointing=checkpointing,
                )
            elif vocoder == "RefineGAN":
                self.dec = RefineGANGenerator(
                    sample_rate=sr,
                    downsample_rates=upsample_rates[::-1],
                    upsample_rates=upsample_rates,
                    start_channels=16,
                    num_mels=inter_channels,
                    checkpointing=checkpointing,
                )
            else:
                self.dec = HiFiGANNSFGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                    sr=sr,
                    checkpointing=checkpointing,
                )
        else:
            # Without pitch guidance only plain HiFi-GAN is available. The
            # other vocoders leave ``dec`` unset (None) after printing a
            # notice, so any later call to ``self.dec(...)`` will fail.
            if vocoder == "MRF HiFi-GAN":
                print("MRF HiFi-GAN does not support training without pitch guidance.")
                self.dec = None
            elif vocoder == "RefineGAN":
                print("RefineGAN does not support training without pitch guidance.")
                self.dec = None
            else:
                self.dec = HiFiGANGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels,
            hidden_channels,
            5,
            1,
            3,
            gin_channels=gin_channels,
        )
        self.emb_g = torch.nn.Embedding(spk_embed_dim, gin_channels)

    def _remove_weight_norm_from(self, module):
        """Strip weight-norm hooks registered directly on ``module``.

        Only the forward pre-hooks of ``module`` itself are inspected; hooks
        on its sub-modules are not visited. ``None`` is ignored so that a
        synthesizer built without a decoder can still be prepared.
        """
        if module is None:
            return
        # Iterate over a snapshot: ``remove_weight_norm`` deletes the hook
        # from ``_forward_pre_hooks``, and mutating an OrderedDict while
        # iterating it raises RuntimeError when more hooks follow.
        for hook in list(module._forward_pre_hooks.values()):
            if type(hook).__name__ == "WeightNorm":
                # Remove the normalisation for the parameter this hook was
                # registered for, rather than assuming it is named "weight".
                torch.nn.utils.remove_weight_norm(
                    module, getattr(hook, "name", "weight")
                )

    def remove_weight_norm(self):
        """Remove weight norm from the decoder, the flow and the posterior encoder."""
        for module in [self.dec, self.flow, self.enc_q]:
            self._remove_weight_norm_from(module)

    def __prepare_scriptable__(self):
        """Remove weight norm and return ``self``.

        NOTE(review): the name matches the hook TorchScript looks up before
        scripting a module -- confirm against the PyTorch version in use.
        """
        self.remove_weight_norm()
        return self

    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        y_lengths: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
    ):
        """Run the training-time pass.

        Args:
            phone: Input to the text encoder.
            phone_lengths: Lengths forwarded to the text encoder.
            pitch: Forwarded to the text encoder.
            pitchf: Pitch signal given to the decoder when ``use_f0`` is set;
                sliced to the selected segment in randomized mode.
            y: Input to the posterior encoder. When ``None`` only the text
                encoder is evaluated.
            y_lengths: Lengths forwarded to the posterior encoder.
            ds: Speaker ids looked up in ``emb_g``. Although defaulted to
                ``None`` in the signature, it is always used.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``.
            ``ids_slice`` is ``None`` when ``randomized`` is false. When ``y``
            is ``None`` the result is
            ``(None, None, x_mask, None, (None, None, m_p, logs_p, None, None))``.
        """
        # Speaker embedding with a trailing singleton axis appended.
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        if y is not None:
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
            z_p = self.flow(z, y_mask, g=g)
            if self.randomized:
                # Decode only a random fixed-size slice of the latent.
                z_slice, ids_slice = rand_slice_segments(
                    z, y_lengths, self.segment_size
                )
                if self.use_f0:
                    # Cut the matching window out of the pitch signal.
                    pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
                    o = self.dec(z_slice, pitchf, g=g)
                else:
                    o = self.dec(z_slice, g=g)
                return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
            else:
                # Decode the full latent sequence.
                if self.use_f0:
                    o = self.dec(z, pitchf, g=g)
                else:
                    o = self.dec(z, g=g)
                return o, None, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
        else:
            return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        nsff0: Optional[torch.Tensor] = None,
        sid: torch.Tensor = None,
        rate: Optional[torch.Tensor] = None,
    ):
        """Run the inference pass.

        Args:
            phone: Input to the text encoder.
            phone_lengths: Lengths forwarded to the text encoder.
            pitch: Forwarded to the text encoder.
            nsff0: Pitch signal given to the decoder when ``use_f0`` is set.
            sid: Speaker ids looked up in ``emb_g``.
            rate: Optional single-element tensor. When given, the leading
                ``int(T * (1 - rate))`` steps (``T`` being the size of axis 2
                of the sampled latent) are dropped from the latent, the mask
                and, if applicable, ``nsff0``.

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))`` where ``o`` is the decoder
            output.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Sample around the text-encoder mean, with the noise scaled by 0.66666.
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

        if rate is not None:
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p, x_mask = z_p[:, :, head:], x_mask[:, :, head:]
            if self.use_f0 and nsff0 is not None:
                nsff0 = nsff0[:, head:]

        # Run the flow in reverse before decoding.
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = (
            self.dec(z * x_mask, nsff0, g=g)
            if self.use_f0
            else self.dec(z * x_mask, g=g)
        )

        return o, x_mask, (z, z_p, m_p, logs_p)
|
|