import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d


class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer that transforms the second channel half.

    The input is split along dim 1 into ``x0`` and ``x1``. ``x0`` passes through
    unchanged and drives a small dilated-conv network that predicts a shift
    ``m`` and a log-scale ``logs``; ``x1`` is mapped to ``m + x1 * exp(logs)``.
    With ``reverse=True`` the exact inverse ``(x1 - m) * exp(-logs)`` is applied.

    Args:
        channels: Number of input/output channels. Must be a positive even
            number so that both halves from ``torch.chunk`` have
            ``channels // 2`` channels, matching the conv widths.
        hidden_channels: Width of the internal conv stack.
        kernel_size: Kernel size of the dilated convs. Must be odd (when
            ``n_layers > 0``) so the padding used keeps the time length fixed.
        dilation_rate: Base of the per-layer dilation (layer ``i`` uses
            ``dilation_rate ** i``).
        n_layers: Number of dilated conv layers.

    Raises:
        ValueError: If ``channels`` is not a positive even number, or if
            ``kernel_size`` is even while ``n_layers > 0``.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers):
        super().__init__()
        # Odd channel counts make torch.chunk produce unequal halves that no
        # longer match the ``channels // 2`` input of ``pre``; fail early
        # instead of with a shape error inside forward().
        if channels < 2 or channels % 2 != 0:
            raise ValueError(f"channels must be a positive even number, got {channels}")
        # With an even kernel, padding (k - 1) * d // 2 shortens the first
        # layer's output by one step, so ``m``/``logs`` would no longer line
        # up with ``x1``.
        if n_layers > 0 and kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")

        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers

        self.pre = nn.Conv1d(channels // 2, hidden_channels, 1)
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            self.convs.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=(kernel_size - 1) * dilation // 2,
                    dilation=dilation,
                )
            )
        # Produces ``channels`` outputs, split into m and logs (half each).
        self.proj = nn.Conv1d(hidden_channels, channels, 1)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        """Apply the coupling transform, or its inverse when ``reverse``.

        Args:
            x: Tensor of shape (batch, channels, time); channels must be on
                dim 1 because that is the axis split and re-concatenated.
            reverse: If True, undo a previous forward pass.

        Returns:
            Tensor of the same shape as ``x``. The first half of dim 1 is
            ``x0`` unchanged; the second half is the transformed ``x1``.
        """
        x0, x1 = torch.chunk(x, 2, 1)
        h = self.pre(x0)
        for conv in self.convs:
            h = F.relu(conv(h))
        m, logs = torch.chunk(self.proj(h), 2, 1)
        if reverse:
            x1 = (x1 - m) * torch.exp(-logs)
        else:
            x1 = m + x1 * torch.exp(logs)
        return torch.cat([x0, x1], 1)


class Flip(nn.Module):
    """Reverse the order of channels (dim 1); the operation is its own inverse."""

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        # Flipping is an involution, so ``reverse`` does not change the result;
        # the argument is kept so Flip matches the flow-module call signature.
        return torch.flip(x, [1])


class ResidualCouplingBlock(nn.Module):
    """Stack of ``n_flows`` coupling layers with channel flips between them.

    A coupling layer only transforms the second channel half. Without a
    permutation between layers, the first half would pass through the whole
    stack untouched. A ``Flip`` after each layer swaps which half the next
    layer transforms.

    Args:
        channels: Number of input/output channels (must be even).
        hidden_channels: Width of each coupling layer's conv stack.
        kernel_size: Kernel size of the dilated convs (must be odd).
        dilation_rate: Base of the per-layer dilation.
        n_layers: Number of conv layers per coupling layer.
        n_flows: Number of coupling layers.
        use_flip: If True (default), flip channels after every coupling layer.
            Set False to reproduce the legacy behaviour without permutation,
            e.g. for weights trained that way.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows,
        use_flip=True,
    ):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(
                    channels=channels,
                    hidden_channels=hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    n_layers=n_layers,
                )
            )
        # Flip holds no parameters. Keeping it outside ``self.flows`` leaves
        # the state_dict keys (flows.0, flows.1, ...) unchanged.
        self.flip = Flip() if use_flip else None

    def forward(self, x: torch.Tensor, reverse: bool = False) -> torch.Tensor:
        """Run the flow stack forward, or invert it when ``reverse`` is True.

        The inverse visits the layers in reverse order and undoes each flip
        before the corresponding coupling layer, mirroring the forward pass.
        """
        if not reverse:
            for flow in self.flows:
                x = flow(x, reverse=False)
                if self.flip is not None:
                    x = self.flip(x)
        else:
            for flow in reversed(self.flows):
                if self.flip is not None:
                    x = self.flip(x, reverse=True)
                x = flow(x, reverse=True)
        return x