Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer that transforms the second channel half of its input.

    The input is split along dim 1 into ``x0`` (first half) and ``x1``
    (second half). A conditioning network (1x1 conv, a stack of dilated
    1-D convs with ReLU, and a 1x1 output conv) maps ``x0`` to a shift ``m``
    and a log-scale ``logs``. Forward computes ``x1' = m + x1 * exp(logs)``;
    ``reverse=True`` computes ``x1 = (x1' - m) * exp(-logs)``. ``x0`` passes
    through unchanged, which is what makes the two directions inverses.

    Args:
        channels: number of input/output channels; must be even so the two
            halves have equal width.
        hidden_channels: width of the conditioning network.
        kernel_size: kernel size of the dilated convolutions. The padding
            ``(kernel_size - 1) // 2 * dilation`` keeps the time length
            unchanged only for odd kernel sizes.
        dilation_rate: layer ``i`` uses dilation ``dilation_rate ** i``.
        n_layers: number of dilated convolutions.
        p_dropout: dropout probability applied after each conv activation
            (default 0.0, i.e. no dropout). With a non-zero value, forward and
            reverse are exact inverses only in eval mode.

    Raises:
        ValueError: if ``channels`` is odd.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0.0):
        super().__init__()
        if channels % 2 != 0:
            # torch.chunk would give x0 one more channel than `pre` expects.
            raise ValueError(f"channels must be even, got {channels}")
        self.channels = channels
        self.half_channels = channels // 2
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            self.convs.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=(kernel_size - 1) // 2 * dilation,
                    dilation=dilation,
                )
            )
        self.drop = nn.Dropout(p_dropout)
        # Emits m and logs stacked along dim 1 (half_channels each).
        self.post = nn.Conv1d(hidden_channels, channels, 1)

    def forward(self, x, reverse=False):
        """Apply the coupling transform (or its inverse when ``reverse``).

        Args:
            x: tensor whose dim 1 holds ``channels`` channels (passed to
                ``nn.Conv1d``, so presumably ``(batch, channels, time)``).
            reverse: if True, undo a previous forward call.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x0, x1 = torch.chunk(x, 2, dim=1)
        h = self.pre(x0)
        for conv in self.convs:
            h = self.drop(F.relu(conv(h)))
        m, logs = torch.chunk(self.post(h), 2, dim=1)
        if reverse:
            x1 = (x1 - m) * torch.exp(-logs)
        else:
            x1 = m + x1 * torch.exp(logs)
        return torch.cat([x0, x1], dim=1)
class ResidualCouplingBlock(nn.Module):
    """Stack of ``n_flows`` (coupling layer, channel flip) pairs.

    Forward runs every flow in order. ``reverse=True`` runs them in the
    opposite order, calling each with ``reverse=True`` to undo it.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows):
        super().__init__()
        modules = []
        for _ in range(n_flows):
            modules += [
                ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers),
                Flip(),
            ]
        self.flows = nn.ModuleList(modules)

    def forward(self, x, reverse=False):
        """Push ``x`` through the flow stack, or back through it when ``reverse``."""
        if reverse:
            for step in reversed(self.flows):
                x = step(x, reverse=True)
        else:
            for step in self.flows:
                x = step(x)
        return x
class Flip(nn.Module):
    """Reverse the order of channels along dim 1.

    Flipping twice restores the input, so the same operation serves as both
    the forward and the inverse transform.
    """

    def forward(self, x, reverse=False):
        # `reverse` is accepted only so callers can treat every flow alike.
        return torch.flip(x, dims=[1])
class PosteriorEncoder(nn.Module):
    """Encode frames into a diagonal Gaussian and draw a sample from it.

    A 1x1 input conv is followed by ``n_layers`` dilated 1-D convs with ReLU.
    Two 1x1 projections then produce the mean ``m`` and ``logs``. ``logs`` is
    used as a log standard deviation (``z = m + eps * exp(logs)``), even though
    its projection is named ``proj_logvar``; that name is kept so existing
    state dicts still load.

    When ``x_lengths`` is given, frames at or beyond each sequence's length
    are zeroed after every layer. This keeps padding from leaking into valid
    frames and zeroes ``z``, ``m`` and ``logs`` there.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers):
        super().__init__()
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            self.convs.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=(kernel_size - 1) // 2 * dilation,
                    dilation=dilation,
                )
            )
        self.proj_mean = nn.Conv1d(hidden_channels, out_channels, 1)
        # Treated as a log standard deviation in forward(); the attribute name
        # is kept for state-dict compatibility.
        self.proj_logvar = nn.Conv1d(hidden_channels, out_channels, 1)

    @staticmethod
    def _length_mask(lengths, max_len, like):
        """Return a ``(batch, 1, max_len)`` mask: 1 where step < length, else 0."""
        lengths = torch.as_tensor(lengths, device=like.device)
        steps = torch.arange(max_len, device=like.device)
        return (steps.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1).to(like.dtype)

    def forward(self, x, x_lengths):
        """Encode ``x`` and sample ``z``.

        Args:
            x: input for ``nn.Conv1d``; assumed ``(batch, in_channels, time)``
                when ``x_lengths`` is given — TODO confirm against callers.
            x_lengths: per-sequence valid lengths (tensor or sequence of ints),
                or None to treat every frame as valid.

        Returns:
            ``(z, m, logs)``, each with ``out_channels`` channels.
        """
        if x_lengths is None:
            mask = 1.0
        else:
            mask = self._length_mask(x_lengths, x.size(-1), x)
        x = self.pre(x) * mask
        for conv in self.convs:
            # Re-mask after every conv so padded steps stay exactly zero.
            x = F.relu(conv(x)) * mask
        m = self.proj_mean(x) * mask
        logs = self.proj_logvar(x) * mask
        z = (m + torch.randn_like(m) * torch.exp(logs)) * mask
        return z, m, logs

    def infer(self, z, z_lengths):
        """Return ``z`` unchanged; ``z_lengths`` is currently ignored."""
        return z