noblebarkrr's picture
Updated to Dzeta
4f175c5
import torch
from typing import Optional
from i18n import _i18n
from .generators.hifigan_mrf import HiFiGANMRFGenerator
from .generators.hifigan_nsf import HiFiGANNSFGenerator
from .generators.hifigan import HiFiGANGenerator
from .generators.refinegan import RefineGANGenerator
from .commons import slice_segments, rand_slice_segments
from .residuals import ResidualCouplingBlock
from .encoders import TextEncoder, PosteriorEncoder
class Synthesizer(torch.nn.Module):
    """Encoder / flow / decoder voice synthesizer.

    Sub-modules built in ``__init__``:

    * ``enc_p`` -- ``TextEncoder`` over the phone features (and pitch when
      ``use_f0`` is set); yields prior statistics ``m_p``/``logs_p`` and a mask.
    * ``enc_q`` -- ``PosteriorEncoder`` over the target spectrogram ``y``.
    * ``flow`` -- ``ResidualCouplingBlock``; run forward in training and with
      ``reverse=True`` in :meth:`infer`.
    * ``emb_g`` -- ``torch.nn.Embedding(spk_embed_dim, gin_channels)`` giving
      the global conditioning vector ``g``.
    * ``dec`` -- waveform generator chosen by ``vocoder`` / ``use_f0``. It is
      ``None`` for the unsupported combinations "MRF HiFi-GAN" or
      "RefineGAN" without pitch guidance (a message is printed instead).
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: list,
        resblock_dilation_sizes: list,
        upsample_rates: list,
        upsample_initial_channel: int,
        upsample_kernel_sizes: list,
        spk_embed_dim: int,
        gin_channels: int,
        sr: int,
        use_f0: bool,
        text_enc_hidden_dim: int = 768,
        vocoder: str = "HiFi-GAN",
        randomized: bool = True,
        checkpointing: bool = False,
        **kwargs,
    ):
        """Build all sub-modules.

        Args:
            spec_channels: Input channel count of ``enc_q``.
            segment_size: Segment length passed to the slicing helpers in
                :meth:`forward` when ``randomized`` is true.
            inter_channels: Latent channel count shared by encoders, flow and
                decoder.
            resblock: Accepted for interface compatibility; not used here.
            spk_embed_dim: Number of rows in the speaker embedding table.
            gin_channels: Width of the speaker conditioning vector.
            sr: Sample rate forwarded to the pitch-aware decoders.
            use_f0: Select a pitch-conditioned decoder and feed pitch to
                ``enc_p``.
            vocoder: "HiFi-GAN" (default), "MRF HiFi-GAN" or "RefineGAN".
                Any unrecognised value falls back to the HiFi-GAN variant.
            randomized: In :meth:`forward`, decode a random slice of the
                latent instead of the whole sequence.
            checkpointing: Forwarded to the pitch-aware decoders.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.segment_size = segment_size
        self.use_f0 = use_f0
        self.randomized = randomized
        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            text_enc_hidden_dim,
            f0=use_f0,
        )
        print(_i18n("using_vocoder")+": "+vocoder)
        if use_f0:
            if vocoder == "MRF HiFi-GAN":
                self.dec = HiFiGANMRFGenerator(
                    in_channel=inter_channels,
                    upsample_initial_channel=upsample_initial_channel,
                    upsample_rates=upsample_rates,
                    upsample_kernel_sizes=upsample_kernel_sizes,
                    resblock_kernel_sizes=resblock_kernel_sizes,
                    resblock_dilations=resblock_dilation_sizes,
                    gin_channels=gin_channels,
                    sample_rate=sr,
                    harmonic_num=8,
                    checkpointing=checkpointing,
                )
            elif vocoder == "RefineGAN":
                self.dec = RefineGANGenerator(
                    sample_rate=sr,
                    # Downsampling path mirrors the upsampling path.
                    downsample_rates=upsample_rates[::-1],
                    upsample_rates=upsample_rates,
                    start_channels=16,
                    num_mels=inter_channels,
                    checkpointing=checkpointing,
                )
            else:
                self.dec = HiFiGANNSFGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                    sr=sr,
                    checkpointing=checkpointing,
                )
        else:
            if vocoder == "MRF HiFi-GAN":
                print("MRF HiFi-GAN does not support training without pitch guidance.")
                self.dec = None
            elif vocoder == "RefineGAN":
                print("RefineGAN does not support training without pitch guidance.")
                self.dec = None
            else:
                self.dec = HiFiGANGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels,
            hidden_channels,
            5,
            1,
            3,
            gin_channels=gin_channels,
        )
        self.emb_g = torch.nn.Embedding(spk_embed_dim, gin_channels)

    def _remove_weight_norm_from(self, module):
        """Strip legacy ``torch.nn.utils.weight_norm`` hooks from a module tree.

        Every submodule of ``module`` (including ``module`` itself) is
        inspected, because weight-normalized layers are normally nested
        inside container modules. ``None`` is accepted and ignored, since
        ``self.dec`` may be ``None``.

        Args:
            module: A ``torch.nn.Module`` or ``None``.
        """
        if module is None:
            return
        # Collect first: remove_weight_norm deletes entries from
        # ``_forward_pre_hooks``, which must not happen while iterating it.
        targets = []
        for sub in module.modules():
            for hook in sub._forward_pre_hooks.values():
                if type(hook).__name__ == "WeightNorm":
                    targets.append((sub, hook.name))
        for sub, param_name in targets:
            torch.nn.utils.remove_weight_norm(sub, param_name)

    def remove_weight_norm(self):
        """Remove weight-norm hooks from ``dec``, ``flow`` and ``enc_q``."""
        for module in [self.dec, self.flow, self.enc_q]:
            self._remove_weight_norm_from(module)

    def __prepare_scriptable__(self):
        """Hook invoked by ``torch.jit.script``; drops weight norm first."""
        self.remove_weight_norm()
        return self

    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        y_lengths: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
    ):
        """Training pass.

        Args:
            phone, phone_lengths: Inputs to ``enc_p``.
            pitch: Passed to ``enc_p``.
            pitchf: Passed to the decoder when ``use_f0`` is set.
            y, y_lengths: Target passed to ``enc_q``. When ``y`` is ``None``
                only the prior branch is computed.
            ds: Indices into ``emb_g``; must not be ``None`` in practice.
                Presumably speaker ids -- confirm against the caller.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``.
            ``ids_slice`` is ``None`` when ``randomized`` is false. When
            ``y`` is ``None``, ``o``, ``ids_slice``, ``y_mask`` and the
            posterior/flow entries are ``None``.
        """
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        if y is not None:
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
            z_p = self.flow(z, y_mask, g=g)
            if self.randomized:
                # Decode only a slice of the latent; the same slice indices
                # are applied to the pitch track so they stay aligned.
                z_slice, ids_slice = rand_slice_segments(
                    z, y_lengths, self.segment_size
                )
                if self.use_f0:
                    pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
                    o = self.dec(z_slice, pitchf, g=g)
                else:
                    o = self.dec(z_slice, g=g)
                return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
            else:
                if self.use_f0:
                    o = self.dec(z, pitchf, g=g)
                else:
                    o = self.dec(z, g=g)
                return o, None, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
        else:
            return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        nsff0: Optional[torch.Tensor] = None,
        sid: torch.Tensor = None,
        rate: Optional[torch.Tensor] = None,
    ):
        """Inference pass: sample from the prior and decode.

        Args:
            phone, phone_lengths, pitch: Inputs to ``enc_p``.
            nsff0: Pitch passed to the decoder when ``use_f0`` is set.
            sid: Indices into ``emb_g`` (required despite the ``None`` default).
            rate: Optional scalar tensor. Only the trailing ``rate`` fraction
                of frames is kept, i.e. frames from
                ``int(T * (1 - rate))`` onward along dim 2.

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Prior sample with the noise term scaled by 0.66666.
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        if rate is not None:
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p, x_mask = z_p[:, :, head:], x_mask[:, :, head:]
            if self.use_f0 and nsff0 is not None:
                nsff0 = nsff0[:, head:]
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = (
            self.dec(z * x_mask, nsff0, g=g)
            if self.use_f0
            else self.dec(z * x_mask, g=g)
        )
        return o, x_mask, (z, z_p, m_p, logs_p)