# NOTE(review): the original paste carried non-code header text here
# ("Spaces:" / "Runtime error" — apparently a Hugging Face Spaces status page).
# Preserved as a comment so the module parses.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import Conv1d | |
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer over the channel dimension.

    The input ``x`` of shape ``(batch, channels, time)`` is split into two
    halves ``x0, x1`` along dim 1. A small dilated-conv stack conditioned on
    ``x0`` predicts a shift ``m`` and log-scale ``logs``; ``x1`` is then
    transformed elementwise while ``x0`` passes through unchanged, which makes
    the map exactly invertible.

    Args:
        channels: number of input/output channels; must be even so the
            two chunks produced by ``torch.chunk`` have equal size.
        hidden_channels: width of the internal conv stack.
        kernel_size: kernel size of each internal conv (odd sizes keep the
            "same" padding exact).
        dilation_rate: base of the per-layer dilation ``dilation_rate ** i``.
        n_layers: number of conv + ReLU layers in the stack.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers):
        super().__init__()
        # Fail fast: with odd `channels`, torch.chunk gives x0 an extra
        # channel and self.pre would crash later with an opaque shape error.
        if channels % 2 != 0:
            raise ValueError(f"channels must be even, got {channels}")
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers

        # Lift the conditioning half (channels // 2) into the hidden width.
        self.pre = nn.Conv1d(channels // 2, hidden_channels, 1)
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            self.convs.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    # "Same" padding for this dilation (exact for odd kernels).
                    padding=(kernel_size - 1) * dilation // 2,
                    dilation=dilation,
                )
            )
        # Projects to `channels` outputs = (m, logs), each channels // 2 wide.
        # NOTE(review): proj is randomly initialized, so exp(logs) is not 1 at
        # init; flow implementations usually zero-init this layer — confirm
        # whether that is intended before training.
        self.proj = nn.Conv1d(hidden_channels, channels, 1)

    def forward(self, x, reverse=False):
        """Apply (or invert, if ``reverse``) the coupling transform.

        Args:
            x: tensor of shape ``(batch, channels, time)``.
            reverse: if False compute ``x1 <- m + x1 * exp(logs)``;
                if True compute the exact inverse ``x1 <- (x1 - m) * exp(-logs)``.

        Returns:
            Tensor of the same shape as ``x``. (No log-determinant is
            returned; callers needing one must compute ``sum(logs)`` themselves.)
        """
        x0, x1 = torch.chunk(x, 2, 1)
        h = self.pre(x0)
        for conv in self.convs:
            h = F.relu(conv(h))
        stats = self.proj(h)
        m, logs = torch.chunk(stats, 2, 1)
        if not reverse:
            x1 = m + x1 * torch.exp(logs)
        else:
            x1 = (x1 - m) * torch.exp(-logs)
        return torch.cat([x0, x1], 1)
class ResidualCouplingBlock(nn.Module):
    """A stack of ``n_flows`` coupling layers applied in sequence.

    In the forward direction the layers run first-to-last; in the reverse
    direction they run last-to-first with each layer inverted, so the block
    as a whole is invertible.

    NOTE(review): no channel-permuting step (e.g. a Flip) is inserted between
    layers here, so every layer conditions on the same first half — verify
    against the caller whether that is intended.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows):
        super().__init__()
        layers = [
            ResidualCouplingLayer(
                channels=channels,
                hidden_channels=hidden_channels,
                kernel_size=kernel_size,
                dilation_rate=dilation_rate,
                n_layers=n_layers,
            )
            for _ in range(n_flows)
        ]
        self.flows = nn.ModuleList(layers)

    def forward(self, x, reverse=False):
        """Run all flows over ``x``; invert the whole stack when ``reverse``."""
        ordered = self.flows if not reverse else reversed(self.flows)
        for flow in ordered:
            x = flow(x, reverse=reverse)
        return x
class Flip(nn.Module):
    """Reverse the channel dimension (dim 1) of the input.

    Flipping is an involution — applying it twice restores the input — so the
    forward and reverse directions are the same operation. The original code
    spelled this as two byte-identical branches on ``reverse``; the dead
    conditional is collapsed here.
    """

    def forward(self, x, reverse=False):
        """Return ``x`` with dim 1 reversed.

        Args:
            x: tensor with at least 2 dims; channels on dim 1.
            reverse: accepted for interface parity with the coupling layers;
                the operation is its own inverse, so the value is irrelevant.
        """
        return torch.flip(x, [1])