import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer for a normalizing flow over 1-D feature maps.

    The input (B, channels, T) is split channel-wise into halves (x0, x1).
    x0 passes through unchanged and parameterizes an affine transform
    (shift ``m``, log-scale ``logs``) applied to x1, so the layer is
    exactly invertible.

    Args:
        channels: total number of input/output channels (must be even,
            since the tensor is chunked into two halves on dim 1).
        hidden_channels: width of the intermediate conv stack.
        kernel_size: conv kernel size (odd sizes keep T unchanged via
            the symmetric padding below).
        dilation_rate: base of the per-layer exponential dilation.
        n_layers: number of dilated conv layers.

    NOTE(review): the original signature listed extra, unused parameters
    (``spec_channels``, ``inter_channels``, ``p_dropout``) while the body
    referenced undefined names ``channels``/``dilation_rate`` — every
    construction raised NameError. The signature now matches both the body
    and the only caller, ``ResidualCouplingBlock``.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers

        # Project the conditioning half (channels // 2) into the hidden width.
        self.pre = nn.Conv1d(channels // 2, hidden_channels, 1)
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            # (kernel_size - 1) // 2 * dilation keeps the time dimension
            # unchanged for odd kernel sizes.
            self.convs.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=(kernel_size - 1) // 2 * dilation,
                    dilation=dilation,
                )
            )
        # Outputs `channels` maps: chunked into shift m and log-scale logs,
        # each of size channels // 2 to match x1.
        self.post = nn.Conv1d(hidden_channels, channels, 1)

    def forward(self, x, reverse=False):
        """Apply (or invert, when ``reverse=True``) the affine coupling.

        Args:
            x: tensor of shape (B, channels, T).
            reverse: False → x1 ← m + x1 * exp(logs);
                     True  → x1 ← (x1 - m) * exp(-logs) (exact inverse).

        Returns:
            Tensor of the same shape as ``x``.
        """
        x0, x1 = torch.chunk(x, 2, dim=1)
        h = self.pre(x0)
        for conv in self.convs:
            h = F.relu(conv(h))
        h = self.post(h)
        m, logs = torch.chunk(h, 2, dim=1)
        if not reverse:
            x1 = m + x1 * torch.exp(logs)
        else:
            x1 = (x1 - m) * torch.exp(-logs)
        return torch.cat([x0, x1], dim=1)


class ResidualCouplingBlock(nn.Module):
    """Stack of ``n_flows`` coupling layers, each followed by a channel Flip.

    The Flip alternates which half of the channels is transformed, so all
    channels get updated across the stack. Running forward then reverse
    reconstructs the input exactly (up to float error).
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers)
            )
            self.flows.append(Flip())

    def forward(self, x, reverse=False):
        """Run the flow stack forward, or inverted in reversed order."""
        if not reverse:
            for flow in self.flows:
                x = flow(x)
        else:
            # Invert by applying each flow's inverse in the opposite order.
            for flow in reversed(self.flows):
                x = flow(x, reverse=True)
        return x


class Flip(nn.Module):
    """Reverse the channel dimension. The flip is an involution, so the
    same operation serves as its own inverse (``reverse`` is ignored)."""

    def forward(self, x, reverse=False):
        return x.flip(1)


class PosteriorEncoder(nn.Module):
    """Encode features into a diagonal-Gaussian latent and sample from it.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the latent z.
        hidden_channels: width of the dilated conv stack.
        kernel_size / dilation_rate / n_layers: conv stack shape, same
            conventions as ``ResidualCouplingLayer``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers):
        super().__init__()
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            self.convs.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=(kernel_size - 1) // 2 * dilation,
                    dilation=dilation,
                )
            )
        # NOTE(review): despite the "logvar" name, the projection's output is
        # consumed as exp(logs), i.e. a log-STD, in forward() — confirm which
        # is intended before renaming anything (renaming would also break
        # saved state_dicts).
        self.proj_mean = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj_logvar = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_lengths):
        """Return (z, m, logs) for input x of shape (B, in_channels, T).

        ``x_lengths`` is accepted for interface compatibility but is not
        used here (no masking is applied) — TODO confirm callers expect
        unmasked output.
        """
        x = self.pre(x)
        for conv in self.convs:
            x = F.relu(conv(x))
        m = self.proj_mean(x)
        logs = self.proj_logvar(x)
        # Reparameterization trick: z ~ N(m, exp(logs)^2).
        z = m + torch.randn_like(m) * torch.exp(logs)
        return z, m, logs

    def infer(self, z, z_lengths):
        """Identity pass-through at inference; ``z_lengths`` is unused."""
        return z