# Provenance (residue of the hosting page): uploaded by buchi-stdesign,
# "Upload 18 files", commit 1ee91f8 (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer: transform half of the channels conditioned on the other half.

    The input is split along dim 1 into ``x0`` (first half) and ``x1`` (second
    half). ``x0`` is returned unchanged and is also fed through a 1x1 conv, a
    stack of dilated convolutions with ReLU, and a 1x1 projection that predicts
    a shift ``m`` and a log-scale ``logs`` for ``x1``.

    Args:
        channels: Number of input/output channels. Must be even and >= 2 so the
            two halves have equal size.
        hidden_channels: Width of the internal convolution stack.
        kernel_size: Kernel size of the dilated convolutions. Must be odd when
            ``n_layers > 0`` so the time length is preserved.
        dilation_rate: Layer ``i`` uses dilation ``dilation_rate ** i``.
        n_layers: Number of dilated convolution layers.

    Raises:
        ValueError: If ``channels`` is odd or < 2, or if ``kernel_size`` is even
            while ``n_layers > 0`` (both configurations would fail in ``forward``).
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers):
        super().__init__()
        # torch.chunk splits an odd channel count unevenly (ceil/floor), which would
        # not match the ``channels // 2`` input width of ``self.pre``.
        if channels < 2 or channels % 2 != 0:
            raise ValueError(f"channels must be an even number >= 2, got {channels}")
        # With symmetric padding (kernel_size - 1) * dilation // 2, an even kernel
        # shrinks the time axis by one at dilation 1 (always used by the first
        # layer), so ``m``/``logs`` would no longer line up with ``x1``.
        if n_layers > 0 and kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd when n_layers > 0, got {kernel_size}"
            )
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.pre = nn.Conv1d(channels // 2, hidden_channels, 1)
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            self.convs.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    # "Same" padding for odd kernels: keeps the time length fixed.
                    padding=(kernel_size - 1) * dilation // 2,
                    dilation=dilation,
                )
            )
        # Outputs ``channels`` maps = shift (first half) + log-scale (second half).
        self.proj = nn.Conv1d(hidden_channels, channels, 1)

    def forward(self, x, reverse=False):
        """Apply the coupling transform (or its inverse when ``reverse`` is True).

        Args:
            x: Tensor whose dim 1 holds ``channels`` channels; with ``nn.Conv1d``
                this is (batch, channels, time).
            reverse: If False compute ``x1 -> m + x1 * exp(logs)``; if True apply
                the exact inverse ``x1 -> (x1 - m) * exp(-logs)``.

        Returns:
            Tensor of the same shape as ``x``; the first half of the channels is
            passed through unchanged.
        """
        x0, x1 = torch.chunk(x, 2, 1)
        h = self.pre(x0)
        for conv in self.convs:
            h = F.relu(conv(h))
        stats = self.proj(h)
        m, logs = torch.chunk(stats, 2, 1)
        if not reverse:
            x1 = m + x1 * torch.exp(logs)
        else:
            x1 = (x1 - m) * torch.exp(-logs)
        return torch.cat([x0, x1], 1)
class ResidualCouplingBlock(nn.Module):
    """Invertible stack of ``ResidualCouplingLayer``s with channel flips between them.

    Each coupling layer only transforms the second half of its input channels.
    Without a permutation between layers, every layer would condition on the
    same first half and that half would never be transformed. The channel axis
    is therefore reversed after every coupling layer, matching the usage of the
    ``Flip`` module, so consecutive layers alternate which half they transform.
    The flip is done inline (not as extra modules) so parameter/``state_dict``
    keys stay ``flows.0``, ``flows.1``, ...

    Args:
        channels, hidden_channels, kernel_size, dilation_rate, n_layers:
            Forwarded unchanged to every ``ResidualCouplingLayer``.
        n_flows: Number of coupling layers.
        flip: Keyword-only. Reverse the channel axis after each coupling layer
            (default True). Pass False to reproduce the previous behaviour,
            e.g. for weights trained without flips.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows, *, flip=True):
        super().__init__()
        self.flip = flip
        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(
                    channels=channels,
                    hidden_channels=hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    n_layers=n_layers,
                )
            )

    def forward(self, x, reverse=False):
        """Run the flow forward, or invert it when ``reverse`` is True.

        The reverse pass undoes the forward pass step by step in the opposite
        order: undo the flip first, then invert the coupling layer.
        """
        if not reverse:
            for flow in self.flows:
                x = flow(x, reverse=False)
                if self.flip:
                    x = torch.flip(x, [1])
        else:
            for flow in reversed(self.flows):
                if self.flip:
                    x = torch.flip(x, [1])
                x = flow(x, reverse=True)
        return x
class Flip(nn.Module):
    """Reverse the order of the channel axis (dim 1).

    Reversal is its own inverse, so ``reverse`` is accepted only for interface
    parity with the coupling modules and does not change the result.
    """

    def forward(self, x, reverse=False):
        # Same operation in both directions: flipping twice restores ``x``.
        return torch.flip(x, dims=[1])