import math
import torch
from torch import nn
from torch.nn import functional as F

import modules
import commons
import attentions
import monotonic_align


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = attentions.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = attentions.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask
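
# DurationPredictor operates on masked, channel-first sequences: x is
# [batch, in_channels, t_text] and x_mask is a {0, 1} mask of shape
# [batch, 1, t_text]; the output has one value per text position,
# [batch, 1, t_text], and is treated downstream as a log-duration
# (FlowGenerator exponentiates it at generation time).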


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        window_size=None,
        block_length=None,
        mean_only=False,
        prenet=False,
        gin_channels=0,
    ):
        super().__init__()

        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.mean_only = mean_only
        self.prenet = prenet
        self.gin_channels = gin_channels

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        if prenet:
            self.pre = modules.ConvReluNorm(
                hidden_channels,
                hidden_channels,
                hidden_channels,
                kernel_size=5,
                n_layers=3,
                p_dropout=0.5,
            )
        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
        )

        self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj_w = DurationPredictor(
            hidden_channels + gin_channels, filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, g=None):
        x = self.emb(x) * math.sqrt(self.hidden_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        if self.prenet:
            x = self.pre(x, x_mask)
        x = self.encoder(x, x_mask)

        if g is not None:
            g_exp = g.expand(-1, -1, x.size(-1))
            x_dp = torch.cat([torch.detach(x), g_exp], 1)
        else:
            x_dp = torch.detach(x)

        x_m = self.proj_m(x) * x_mask
        if not self.mean_only:
            x_logs = self.proj_s(x) * x_mask
        else:
            x_logs = torch.zeros_like(x_m)

        logw = self.proj_w(x_dp, x_mask)
        return x_m, x_logs, logw, x_mask
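
# TextEncoder.forward returns, per text position, the prior statistics and the
# duration prediction consumed by FlowGenerator:
#   x_m    [batch, out_channels, t_text]  prior means
#   x_logs [batch, out_channels, t_text]  prior log-std (zeros when mean_only)
#   logw   [batch, 1, t_text]             log-durations, predicted from the
#                                         detached encoder output (plus the
#                                         speaker embedding g when given)
#   x_mask [batch, 1, t_text]             text mask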


class FlowSpecDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_blocks,
        n_layers,
        p_dropout=0.0,
        n_split=4,
        n_sqz=2,
        sigmoid_scale=False,
        gin_channels=0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_blocks = n_blocks
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.sigmoid_scale = sigmoid_scale
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for b in range(n_blocks):
            self.flows.append(modules.ActNorm(channels=in_channels * n_sqz))
            self.flows.append(
                modules.InvConvNear(channels=in_channels * n_sqz, n_split=n_split)
            )
            self.flows.append(
                attentions.CouplingBlock(
                    in_channels * n_sqz,
                    hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    n_layers=n_layers,
                    gin_channels=gin_channels,
                    p_dropout=p_dropout,
                    sigmoid_scale=sigmoid_scale,
                )
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            flows = self.flows
            logdet_tot = 0
        else:
            flows = reversed(self.flows)
            logdet_tot = None

        if self.n_sqz > 1:
            x, x_mask = commons.squeeze(x, x_mask, self.n_sqz)
        for f in flows:
            if not reverse:
                x, logdet = f(x, x_mask, g=g, reverse=reverse)
                logdet_tot += logdet
            else:
                x, logdet = f(x, x_mask, g=g, reverse=reverse)
        if self.n_sqz > 1:
            x, x_mask = commons.unsqueeze(x, x_mask, self.n_sqz)
        return x, logdet_tot

    def store_inverse(self):
        for f in self.flows:
            f.store_inverse()
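
# FlowSpecDecoder is a stack of ActNorm / InvConvNear / CouplingBlock flows.
# With reverse=False it maps a spectrogram to a latent sequence and returns the
# accumulated log-determinant used in the maximum-likelihood loss; with
# reverse=True the flows are traversed in the opposite order to generate a
# spectrogram from a latent sample, and logdet_tot is returned as None.
# store_inverse() asks each flow to cache whatever it needs for the reverse
# pass (typically inverted weights) so inference does not recompute it.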


class FlowGenerator(nn.Module):
    def __init__(
        self,
        n_vocab,
        hidden_channels,
        filter_channels,
        filter_channels_dp,
        out_channels,
        kernel_size=3,
        n_heads=2,
        n_layers_enc=6,
        p_dropout=0.0,
        n_blocks_dec=12,
        kernel_size_dec=5,
        dilation_rate=5,
        n_block_layers=4,
        p_dropout_dec=0.0,
        n_speakers=0,
        gin_channels=0,
        n_split=4,
        n_sqz=1,
        sigmoid_scale=False,
        window_size=None,
        block_length=None,
        mean_only=False,
        hidden_channels_enc=None,
        hidden_channels_dec=None,
        prenet=False,
        **kwargs
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.n_layers_enc = n_layers_enc
        self.p_dropout = p_dropout
        self.n_blocks_dec = n_blocks_dec
        self.kernel_size_dec = kernel_size_dec
        self.dilation_rate = dilation_rate
        self.n_block_layers = n_block_layers
        self.p_dropout_dec = p_dropout_dec
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.sigmoid_scale = sigmoid_scale
        self.window_size = window_size
        self.block_length = block_length
        self.mean_only = mean_only
        self.hidden_channels_enc = hidden_channels_enc
        self.hidden_channels_dec = hidden_channels_dec
        self.prenet = prenet

        self.encoder = TextEncoder(
            n_vocab,
            out_channels,
            hidden_channels_enc or hidden_channels,
            filter_channels,
            filter_channels_dp,
            n_heads,
            n_layers_enc,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
            mean_only=mean_only,
            prenet=prenet,
            gin_channels=gin_channels,
        )

        self.decoder = FlowSpecDecoder(
            out_channels,
            hidden_channels_dec or hidden_channels,
            kernel_size_dec,
            dilation_rate,
            n_blocks_dec,
            n_block_layers,
            p_dropout=p_dropout_dec,
            n_split=n_split,
            n_sqz=n_sqz,
            sigmoid_scale=sigmoid_scale,
            gin_channels=gin_channels,
        )

        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

    def forward(
        self,
        x,
        x_lengths,
        y=None,
        y_lengths=None,
        g=None,
        gen=False,
        noise_scale=1.0,
        length_scale=1.0,
    ):
        if g is not None:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, gin_channels, 1]
        x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths, g=g)

        if gen:
            w = torch.exp(logw) * x_mask * length_scale
            w_ceil = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
            y_max_length = None
        else:
            y_max_length = y.size(2)
        y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
        z_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y_max_length), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)

        if gen:
            attn = commons.generate_path(
                w_ceil.squeeze(1), attn_mask.squeeze(1)
            ).unsqueeze(1)
            # expand text-side prior statistics to frame level via the alignment:
            # [b, t', t] x [b, t, d] -> [b, t', d] -> [b, d, t']
            z_m = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)
            ).transpose(1, 2)
            z_logs = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)
            ).transpose(1, 2)
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

            z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
            y, logdet = self.decoder(z, z_mask, g=g, reverse=True)
            return (
                (y, z_m, z_logs, logdet, z_mask),
                (x_m, x_logs, x_mask),
                (attn, logw, logw_),
            )
        else:
            z, logdet = self.decoder(y, z_mask, g=g, reverse=False)
            with torch.no_grad():
                # Gaussian log-likelihood of every latent frame under every
                # text position's prior N(x_m, exp(x_logs)^2), split into four
                # terms so it can be written as matrix products.
                x_s_sq_r = torch.exp(-2 * x_logs)
                logp1 = torch.sum(
                    -0.5 * math.log(2 * math.pi) - x_logs, [1]
                ).unsqueeze(-1)  # [b, t, 1]
                logp2 = torch.matmul(
                    x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2)
                )  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp3 = torch.matmul(
                    (x_m * x_s_sq_r).transpose(1, 2), z
                )  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(
                    -1
                )  # [b, t, 1]
                logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

                # monotonic alignment search: the most likely monotonic path
                # through logp, used as the training-time alignment
                attn = (
                    monotonic_align.maximum_path(logp, attn_mask.squeeze(1))
                    .unsqueeze(1)
                    .detach()
                )
            z_m = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)
            ).transpose(1, 2)
            z_logs = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)
            ).transpose(1, 2)
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask
            return (
                (z, z_m, z_logs, logdet, z_mask),
                (x_m, x_logs, x_mask),
                (attn, logw, logw_),
            )

    def preprocess(self, y, y_lengths, y_max_length):
        # trim the target length to a multiple of n_sqz so the decoder's
        # squeeze operation divides the time axis evenly
        if y_max_length is not None:
            y_max_length = (y_max_length // self.n_sqz) * self.n_sqz
            y = y[:, :, :y_max_length]
        y_lengths = (y_lengths // self.n_sqz) * self.n_sqz
        return y, y_lengths, y_max_length

    def store_inverse(self):
        self.decoder.store_inverse()
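

# Minimal usage sketch: the hyperparameters, vocabulary size, and token ids
# below are illustrative assumptions (not values from a config in this repo),
# and the companion modules (modules, commons, attentions, monotonic_align)
# must be importable for it to run.
if __name__ == "__main__":
    model = FlowGenerator(
        n_vocab=148,
        hidden_channels=192,
        filter_channels=768,
        filter_channels_dp=256,
        out_channels=80,
    )
    model.eval()
    model.store_inverse()  # cache inverted flow weights before generation

    x = torch.randint(1, 148, (1, 25))  # dummy token ids, [batch, t_text]
    x_lengths = torch.LongTensor([25])
    with torch.no_grad():
        (y_gen, *_), _, _ = model(
            x, x_lengths, gen=True, noise_scale=0.667, length_scale=1.0
        )
    print(y_gen.shape)  # [1, 80, t_mel], t_mel set by the predicted durations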