| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm |
| |
|
| | from nemo.collections.tts.modules.hifigan_modules import ResBlock1, ResBlock2, get_padding, init_weights |
| | from nemo.collections.tts.modules.monotonic_align import maximum_path |
| | from nemo.collections.tts.parts.utils.helpers import ( |
| | convert_pad_shape, |
| | generate_path, |
| | get_mask_from_lengths, |
| | rand_slice_segments, |
| | ) |
| | from nemo.collections.tts.parts.utils.splines import piecewise_rational_quadratic_transform |
| |
|
# Negative slope shared by all leaky-ReLU activations in this module (HiFi-GAN convention).
LRELU_SLOPE = 0.1
| |
|
| |
|
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation unit: tanh(first half) * sigmoid(second half) of (a + b).

    ``n_channels`` is a one-element IntTensor giving the split point along dim 1;
    the output therefore has ``n_channels[0]`` channels.
    """
    split = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split, :])
    sigmoid_part = torch.sigmoid(summed[:, split:, :])
    return tanh_part * sigmoid_part
| |
|
| |
|
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of a channel-first (B, C, T) tensor.

    Unlike ``nn.LayerNorm``, the input keeps channels in dim 1: it is moved to the
    last axis, normalized, and moved back.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Learnable per-channel scale and shift.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # Channels-last view, normalize, then restore the channel-first layout.
        normalized = F.layer_norm(x.transpose(1, -1), (self.channels,), self.gamma, self.beta, self.eps)
        return normalized.transpose(1, -1)
| |
|
| |
|
class ConvReluNorm(nn.Module):
    """Stack of masked Conv1d → LayerNorm → ReLU → Dropout layers with a residual projection.

    The input passes through ``n_layers`` "same"-padded 1-D convolutions and the
    result is added back to the original input through a zero-initialized 1x1
    projection, so the module starts out as an identity mapping.

    Args:
        in_channels: channels of the input (B, in_channels, T).
        hidden_channels: channels of the intermediate convolutions.
        out_channels: channels of the output; must equal in_channels for the residual add.
        kernel_size: convolution kernel width.
        n_layers: number of conv layers; must be greater than 1.
        p_dropout: dropout probability applied after each ReLU.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # Fix: the message previously said "larger than 0" although the check requires > 1.
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        # Zero-init so the residual branch contributes nothing at the start of training.
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Apply the conv stack under ``x_mask`` and add the residual input.

        Args:
            x: (B, in_channels, T) input features.
            x_mask: (B, 1, T) validity mask; padded frames are zeroed before each conv.
        """
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
| |
|
| |
|
class DDSConv(nn.Module):
    """Dilated and depth-separable convolution stack with residual connections."""

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            # Dilation grows geometrically with depth; padding keeps T unchanged.
            dilation = kernel_size ** layer_idx
            pad = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=pad)
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Run the separable-conv stack under ``x_mask``; ``g`` is an optional additive condition."""
        if g is not None:
            x = x + g
        layer_groups = zip(self.convs_sep, self.convs_1x1, self.norms_1, self.norms_2)
        for sep_conv, pointwise, norm_a, norm_b in layer_groups:
            h = sep_conv(x * x_mask)
            h = F.gelu(norm_a(h))
            h = norm_b(pointwise(h))
            h = self.drop(F.gelu(h))
            x = x + h
        return x * x_mask
| |
|
| |
|
class WN(torch.nn.Module):
    """WaveNet-style stack of weight-normalized, gated dilated convolutions.

    Each layer computes a tanh/sigmoid gated activation (optionally shifted by a
    global conditioning signal ``g``) and emits residual + skip channels; the
    summed skip connections form the output.

    Args:
        hidden_channels: channel count of the residual stream.
        kernel_size: odd kernel width ("same" padding is derived from it).
        dilation_rate: base of the per-layer dilation (dilation_rate ** i).
        n_layers: number of gated conv layers.
        gin_channels: channels of the optional global conditioning input (0 disables it).
        p_dropout: dropout applied to the gated activations.
    """

    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # Fix: a stray trailing comma previously stored this as the 1-tuple (kernel_size,).
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One shared conditioning projection producing 2*hidden channels per layer.
            cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # The last layer emits only skip channels; earlier layers emit residual + skip.
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Run the gated conv stack; returns the masked sum of skip connections.

        Args:
            x: (B, hidden_channels, T) input.
            x_mask: (B, 1, T) validity mask.
            g: optional conditioning of shape (B, gin_channels, T') fed to cond_layer.
        """
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Slice out this layer's 2*hidden conditioning channels.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half updates the residual stream, second half is the skip output.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalization from all convolutions (e.g. before inference export)."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
| |
|
| |
|
class Log(nn.Module):
    """Invertible elementwise log flow: y = log(x) forward, x = exp(y) in reverse."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            # Inverse transform; no log-determinant is returned on the reverse pass.
            return torch.exp(x) * x_mask
        # Clamp avoids log(0) on masked / degenerate inputs.
        y = x_mask * torch.log(torch.clamp_min(x, 1e-5))
        logdet = torch.sum(-y, [1, 2])
        return y, logdet
| |
|
| |
|
class Flip(nn.Module):
    """Flow step that reverses the channel order; volume-preserving (logdet = 0)."""

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        logdet = torch.zeros(flipped.size(0)).to(dtype=flipped.dtype, device=flipped.device)
        return flipped, logdet
| |
|
| |
|
class ElementwiseAffine(nn.Module):
    """Invertible per-channel affine flow: y = m + exp(logs) * x."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # Per-channel shift and log-scale, zero-initialized (identity transform).
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            # Inverse transform; no log-determinant on the reverse pass.
            return (x - self.m) * torch.exp(-self.logs) * x_mask
        y = (self.m + torch.exp(self.logs) * x) * x_mask
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return y, logdet
| |
|
| |
|
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer of the VITS normalizing flow.

    The input is split in half along the channel axis; the first half drives a
    WN stack that predicts an affine transform applied to the second half. With
    ``mean_only=True`` only a shift is predicted (log-scale fixed at zero), so
    the flow is volume-preserving.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels
        )
        # 2x half_channels when predicting both mean and log-scale, 1x for mean only.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        # Zero-init so the coupling starts as an identity transform.
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply (reverse=False: returns (y, logdet)) or invert (reverse=True: returns x) the coupling.

        Args:
            x: (B, channels, T) input.
            x_mask: (B, 1, T) validity mask.
            g: optional global conditioning passed through to the WN encoder.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            # Affine-transform the second half; the first half passes through unchanged.
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x
| |
|
| |
|
class ConvFlow(nn.Module):
    """Coupling flow whose elementwise transform is a piecewise rational-quadratic spline.

    Used inside the stochastic duration predictor: half the channels condition a
    DDSConv stack that predicts the spline parameters for the other half.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        # Per transformed channel: num_bins widths + num_bins heights + (num_bins - 1) derivatives.
        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
        # Zero-init so the flow starts near the identity.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply (or, with reverse=True, invert) the spline coupling under ``x_mask``."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        # (B, half_channels, params, T) -> (B, half_channels, T, params)
        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)

        # Scaling by sqrt(filter_channels) keeps the initial spline parameters small.
        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_derivatives = h[..., 2 * self.num_bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails='linear',
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if not reverse:
            return x, logdet
        else:
            return x
| |
|
| |
|
class StochasticDurationPredictor(nn.Module):
    """Flow-based stochastic duration predictor from VITS.

    Training mode (``reverse=False``): returns, per sequence, the negative
    log-likelihood of the observed durations ``w`` plus the variational
    posterior term (a scalar to be minimized). Inference mode
    (``reverse=True``): samples log-durations from the learned flow.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
        super().__init__()
        # NOTE(review): the passed filter_channels is deliberately overridden with
        # in_channels — this matches the reference VITS implementation.
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Main flow over the 2-channel (dequantized duration, auxiliary) input.
        self.log_flow = Log()
        self.flows = nn.ModuleList()
        self.flows.append(ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(Flip())

        # Posterior network (used only in training) that encodes w into latents.
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(Flip())

        # Text-conditioning stack shared by both the training and sampling passes.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Compute the duration NLL (training) or sample log-durations (inference).

        Args:
            x: text-encoder features (B, in_channels, T_text); gradients are blocked.
            x_mask: (B, 1, T_text) validity mask.
            w: ground-truth durations (B, 1, T_text); required when reverse=False.
            g: optional global conditioning (detached).
            reverse: False for training loss, True for sampling.
            noise_scale: stddev multiplier for the latent noise at inference.
        """
        # The duration predictor is trained on detached features so its loss
        # does not backpropagate into the text encoder.
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            # --- Variational posterior: encode w into latents (u, z1). ---
            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            # u in (0, 1) dequantizes the integer durations: z0 = w - u.
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            # log|d sigmoid/dz| = logsigmoid(z) + logsigmoid(-z).
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q

            # --- Main flow likelihood of (z0, z1) under a standard normal base. ---
            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
            return nll + logq
        else:
            # Run the flows backwards from Gaussian noise to sample log-durations.
            flows = list(reversed(self.flows))
            # NOTE(review): this skips one flow from the reversed list, apparently an
            # unused flow kept for checkpoint compatibility with upstream — confirm.
            flows = flows[:-2] + [flows[-1]]
            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw
| |
|
| |
|
class DurationPredictor(nn.Module):
    """Deterministic duration predictor: two masked conv blocks plus a 1x1 projection."""

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Predict one log-duration per input position, masked by ``x_mask``."""
        # Detach: the duration loss must not backpropagate into the text encoder.
        x = torch.detach(x)
        if g is not None:
            x = x + self.cond(torch.detach(g))
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            x = self.drop(norm(torch.relu(conv(x * x_mask))))
        return self.proj(x * x_mask) * x_mask
| |
|
| |
|
class TextEncoder(nn.Module):
    """Text (prior) encoder: embedding → attention encoder → Gaussian prior statistics."""

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        padding_idx,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.emb = nn.Embedding(n_vocab, hidden_channels, padding_idx=padding_idx)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        self.encoder = AttentionEncoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        """Encode token ids into hidden states plus prior mean/log-scale and a mask."""
        # Scale embeddings by sqrt(d) before the attention stack.
        embedded = self.emb(x) * math.sqrt(self.hidden_channels)
        hidden = torch.transpose(embedded, 1, -1)
        x_mask = torch.unsqueeze(get_mask_from_lengths(x_lengths, hidden), 1).to(hidden.dtype)

        hidden = self.encoder(hidden * x_mask, x_mask)
        stats = self.proj(hidden) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return hidden, m, logs, x_mask
| |
|
| |
|
class ResidualCouplingBlock(nn.Module):
    """Normalizing-flow block: ``n_flows`` pairs of (affine coupling layer, channel flip)."""

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            # Flip halves so that alternating couplings transform both halves.
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flow forward (discarding per-step logdets) or inverted."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=True)
        else:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=False)
        return x
| |
|
| |
|
class PosteriorEncoder(nn.Module):
    """Posterior encoder q(z | spectrogram): WN stack producing Gaussian stats and a sample."""

    def __init__(
        self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Encode a spectrogram into (z, mean, log-scale, mask)."""
        mask = torch.unsqueeze(get_mask_from_lengths(x_lengths, x), 1).to(x.dtype).to(device=x.device)
        hidden = self.pre(x) * mask
        hidden = self.enc(hidden, mask, g=g)
        stats = self.proj(hidden) * mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterization: z ~ N(m, exp(logs)^2), masked to valid frames.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * mask
        return z, m, logs, mask
| |
|
| |
|
class Generator(torch.nn.Module):
    """HiFi-GAN generator: transposed-conv upsampling with multi-receptive-field resblocks.

    Decodes a latent sequence (B, initial_channel, T) into a waveform
    (B, 1, T * prod(upsample_rates)).
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        # '1' selects ResBlock1; any other value selects ResBlock2.
        resblock = ResBlock1 if resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at every upsampling stage.
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2 ** i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # num_kernels resblocks per stage, one per (kernel_size, dilations) pair.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))

        # ch is the channel count after the last upsampling stage.
        self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            # Optional global conditioning (e.g. speaker embedding) added after conv_pre.
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Upsample latents to a waveform in [-1, 1]; ``g`` is optional conditioning."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's parallel resblocks.
            xs = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
            for j in range(self.num_kernels):
                xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight norm from all layers (call once before inference export)."""
        # NOTE(review): uses print rather than the logging module — consider a logger.
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
| |
|
| |
|
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the waveform into (T // period, period) and applies 2-D convs."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                norm(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
            ]
        )
        self.dropout = nn.Dropout(0.3)
        self.conv_post = norm(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return (flattened score map, list of intermediate feature maps)."""
        fmap = []

        # Reflect-pad so the time axis divides evenly, then fold 1-D audio into 2-D.
        batch, channels, time = x.shape
        remainder = time % self.period
        if remainder != 0:
            pad_amount = self.period - remainder
            x = F.pad(x, (0, pad_amount), "reflect")
            time = time + pad_amount
        x = x.view(batch, channels, time // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(self.dropout(conv(x)), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
| |
|
| |
|
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator operating directly on the raw waveform with grouped 1-D convs."""

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm(nn.Conv1d(1, 16, 15, 1, padding=7)),
                norm(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        # NOTE: dropout is defined but not applied in forward (kept for interface parity).
        self.dropout = nn.Dropout(0.3)
        self.conv_post = norm(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return (flattened score map, list of intermediate feature maps)."""
        fmap = []

        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
| |
|
| |
|
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of one scale discriminator and five period discriminators (periods 2,3,5,7,11)."""

    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs.extend(DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods)
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        """Score real (``y``) and generated (``y_hat``) audio with every sub-discriminator.

        Returns four parallel lists: real scores, generated scores, real feature
        maps, generated feature maps.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            score_real, feats_real = disc(y)
            score_gen, feats_gen = disc(y_hat)
            y_d_rs.append(score_real)
            fmap_rs.append(feats_real)
            y_d_gs.append(score_gen)
            fmap_gs.append(feats_gen)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
| |
|
| |
|
class SynthesizerTrn(nn.Module):
    """End-to-end VITS synthesizer used during training.

    Combines a text (prior) encoder, a posterior spectrogram encoder, a
    normalizing flow, a duration predictor (stochastic or deterministic), and a
    HiFi-GAN decoder. ``forward`` runs the full training pass including
    monotonic alignment search; ``infer`` synthesizes audio from text;
    ``voice_conversion`` maps audio between speakers through the shared flow.
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        padding_idx,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=True,
        **kwargs
    ):

        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.padding_idx = padding_idx
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        # Whether to use the stochastic (flow-based) duration predictor.
        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            padding_idx,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
        )
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        if use_sdp:
            self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
        else:
            self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)

        if n_speakers > 1:
            # Speaker embedding used as global conditioning in multi-speaker mode.
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, text, text_len, spec, spec_len, speakers=None):
        """Training pass: encode text and spectrogram, align, and decode a random slice.

        Returns (audio slice, duration loss, attention, slice indices, text mask,
        spec mask, tuple of latent/prior/posterior statistics for the KL loss).
        """
        x, mean_prior, logscale_prior, text_mask = self.enc_p(text, text_len)
        if self.n_speakers > 1:
            g = self.emb_g(speakers).unsqueeze(-1)
        else:
            g = None

        z, mean_posterior, logscale_posterior, spec_mask = self.enc_q(spec, spec_len, g=g)
        z_p = self.flow(z, spec_mask, g=g)

        # Monotonic alignment search (no gradients): per (text token, spec frame)
        # pair, accumulate the Gaussian log-likelihood of z_p under the prior,
        # expanded into four terms of -0.5 * (log(2*pi) + 2*logs + (z - m)^2 / s^2).
        with torch.no_grad():

            s_p_sq_r = torch.exp(-2 * logscale_prior)
            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logscale_prior, [1], keepdim=True)
            neg_cent2 = torch.matmul(
                -0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r
            )
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (mean_prior * s_p_sq_r)
            )
            neg_cent4 = torch.sum(-0.5 * (mean_prior ** 2) * s_p_sq_r, [1], keepdim=True)
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(text_mask, 2) * torch.unsqueeze(spec_mask, -1)
            # Best monotonic path through the likelihood matrix.
            attn = maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # Per-token durations implied by the alignment.
        w = attn.sum(2)
        if self.use_sdp:
            # Stochastic predictor returns a per-sequence NLL; normalize by mask size.
            l_length = self.dp(x, text_mask, w, g=g)
            l_length = l_length / torch.sum(text_mask)
        else:
            # Deterministic predictor: L2 loss in the log-duration domain.
            logw_ = torch.log(w + 1e-6) * text_mask
            logw = self.dp(x, text_mask, g=g)
            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(text_mask)

        # Expand the token-level prior statistics to frame level via the alignment.
        mean_prior = torch.matmul(attn.squeeze(1), mean_prior.transpose(1, 2)).transpose(
            1, 2
        )
        logscale_prior = torch.matmul(attn.squeeze(1), logscale_prior.transpose(1, 2)).transpose(
            1, 2
        )

        # Decode only a random window of z to keep GAN training affordable.
        z_slice, ids_slice = rand_slice_segments(z, spec_len, self.segment_size)
        audio = self.dec(z_slice, g=g)
        return (
            audio,
            l_length,
            attn,
            ids_slice,
            text_mask,
            spec_mask,
            (z, z_p, mean_prior, logscale_prior, mean_posterior, logscale_posterior),
        )

    def infer(self, text, text_len, speakers=None, noise_scale=1, length_scale=1, noise_scale_w=1.0, max_len=None):
        """Synthesize audio from text.

        Args:
            text, text_len: token ids and their lengths.
            speakers: speaker ids (multi-speaker models only).
            noise_scale: stddev multiplier for the prior sample.
            length_scale: multiplies predicted durations (>1 slows speech down).
            noise_scale_w: noise scale for the stochastic duration predictor.
            max_len: optionally truncate the latent sequence before decoding.
        """
        x, mean_prior, logscale_prior, text_mask = self.enc_p(text, text_len)
        if self.n_speakers > 1 and speakers is not None:
            g = self.emb_g(speakers).unsqueeze(-1)
        else:
            g = None

        if self.use_sdp:
            logw = self.dp(x, text_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, text_mask, g=g)
        # Convert log-durations to integer frame counts.
        w = torch.exp(logw) * text_mask * length_scale
        w_ceil = torch.ceil(w)
        audio_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        audio_mask = torch.unsqueeze(get_mask_from_lengths(audio_lengths, None), 1).to(text_mask.dtype)
        attn_mask = torch.unsqueeze(text_mask, 2) * torch.unsqueeze(audio_mask, -1)
        # Hard monotonic alignment generated from the durations.
        attn = generate_path(w_ceil, attn_mask)

        # Expand token-level prior statistics to frame level.
        mean_prior = torch.matmul(attn.squeeze(1), mean_prior.transpose(1, 2)).transpose(
            1, 2
        )
        logscale_prior = torch.matmul(attn.squeeze(1), logscale_prior.transpose(1, 2)).transpose(
            1, 2
        )

        # Sample from the prior, invert the flow, and decode to audio.
        z_p = mean_prior + torch.randn_like(mean_prior) * torch.exp(logscale_prior) * noise_scale
        z = self.flow(z_p, audio_mask, g=g, reverse=True)
        audio = self.dec((z * audio_mask)[:, :, :max_len], g=g)
        return audio, attn, audio_mask, (z, z_p, mean_prior, logscale_prior)

    def voice_conversion(self, y, y_lengths, speaker_src, speaker_tgt):
        """Convert audio ``y`` from a source speaker to a target speaker via the shared flow."""
        assert self.n_speakers > 1, "n_speakers have to be larger than 1."
        g_src = self.emb_g(speaker_src).unsqueeze(-1)
        g_tgt = self.emb_g(speaker_tgt).unsqueeze(-1)
        # Encode with the source speaker, map to the speaker-independent space,
        # then invert the flow under the target speaker and decode.
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
| |
|
| |
|
| | |
| | |
| | |
class AttentionEncoder(nn.Module):
    """Transformer-style encoder: self-attention + FFN sublayers with post-norm residuals."""

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for _ in range(n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        """Run the attention stack under ``x_mask``; input and output are (B, C, T)."""
        # Pairwise mask restricting attention to valid positions only.
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        sublayers = zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2)
        for attn, norm_a, ffn, norm_b in sublayers:
            x = norm_a(x + self.drop(attn(x, x, attn_mask)))
            x = norm_b(x + self.drop(ffn(x, x_mask)))
        return x * x_mask
| |
|
| |
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional windowed relative positional
    embeddings, proximal bias, and local block masking.

    Args:
        channels: input channel count; must be divisible by `n_heads`.
        out_channels: channel count of the output projection.
        n_heads: number of attention heads.
        p_dropout: dropout probability applied to the attention weights.
        window_size: if not None, learn relative-position embeddings for
            offsets in [-window_size, window_size] (self-attention only).
        heads_share: if True, all heads share one relative-embedding table.
        block_length: if not None, restrict attention to a band of this
            half-width around the diagonal (self-attention only).
        proximal_bias: if True, add a bias encouraging attention to nearby
            positions (self-attention only).
        proximal_init: if True, initialize the key projection as a copy of
            the query projection.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        # Attention weights from the most recent forward pass (for inspection).
        self.attn = None

        # Per-head channel count.
        self.k_channels = channels // n_heads
        # 1x1 convolutions act as per-timestep linear projections.
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            # Learnable embeddings for the 2*window_size + 1 relative offsets,
            # for keys and values separately.
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            # Start with keys identical to queries so that initial attention
            # logits are symmetric.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from query source `x` to key/value source `c`.

        x, c: [b, channels, t]; for self-attention pass the same tensor.
        attn_mask: optional mask, 0 where attention is disallowed.
        Returns [b, out_channels, t_x]; also stores the attention weights
        in `self.attn`.
        """
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention over per-head projections."""
        # Reshape [b, d, t] -> [b, n_heads, t, k_channels].
        b, d, t_s, t_t = key.size(0), key.size(1), key.size(2), query.size(2)
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            # Add content-to-relative-position logits to the content scores.
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            # Large finite negative rather than -inf; NOTE(review): presumably
            # chosen to stay representable in fp16 — confirm.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert t_s == t_t, "Local attention is only available for self-attention."
                # Keep only a band of half-width block_length around the diagonal.
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            # Add the attention-weighted relative value embeddings.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        # [b, n_h, t_t, k_channels] -> [b, d, t_t]
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """Multiply attention-like weights by a relative embedding table.

        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """Project queries onto a relative embedding table.

        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Return the 2*length - 1 relative embeddings needed for `length`,
        zero-padding the stored table when the sequence is longer than the
        window, or slicing it when shorter."""
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """Convert relative-position logits to absolute-position layout via a
        pad-and-reshape trick (no gather).

        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Pad one column so each row is 2*l wide, making the flat view shift
        # by one position per row.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Flatten, then pad so the total length factors as (l+1) * (2*l-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the absolute-position square.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """Inverse of `_relative_position_to_absolute_position`.

        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # Pad columns so the flat view shifts by one position per row.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
        # Shift so reshaping yields the relative-position layout.
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        # -log1p(|i - j|): 0 on the diagonal, increasingly negative with distance.
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
| |
|
| |
|
class FFN(nn.Module):
    """Position-wise feed-forward network built from two 1D convolutions.

    Applies conv -> activation (ReLU, or a sigmoid-based GELU approximation
    when `activation == "gelu"`) -> dropout -> conv, masking the sequence
    before each convolution and on the output. Padding is either symmetric
    ("same") or strictly left-sided ("causal").
    """

    def __init__(
        self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Pick the padding strategy once at construction time.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Transform `x` ([b, in_channels, t]) under the frame mask `x_mask`."""
        h = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            # Sigmoid approximation of GELU: x * sigmoid(1.702 * x).
            h = h * torch.sigmoid(1.702 * h)
        else:
            h = torch.relu(h)
        h = self.drop(h)
        h = self.conv_2(self.padding(h * x_mask))
        return h * x_mask

    def _causal_padding(self, x):
        # Pad only on the left so no output frame sees future input frames.
        if self.kernel_size == 1:
            return x
        return F.pad(x, (self.kernel_size - 1, 0))

    def _same_padding(self, x):
        # Near-symmetric padding on the time axis preserving sequence length.
        if self.kernel_size == 1:
            return x
        return F.pad(x, ((self.kernel_size - 1) // 2, self.kernel_size // 2))
| |
|