# style-bert-vits2-fastapi / src / sbv2 / posterior_encoder.py
# (Hugging Face upload metadata: buchi-stdesign, "Upload 18 files", commit 1ee91f8 verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer: the first half of the channels passes through
    unchanged and conditions an affine transform of the second half, which
    keeps the mapping exactly invertible.

    Args:
        channels: total number of input/output channels (split in half; the
            first ``channels // 2`` channels are the conditioning half).
        hidden_channels: width of the intermediate dilated convolutions.
        kernel_size: Conv1d kernel size (odd, so "same" padding preserves
            sequence length).
        dilation_rate: base of the per-layer exponential dilation schedule.
        n_layers: number of dilated Conv1d + ReLU layers.

    NOTE: the original signature here was
    ``(spec_channels, inter_channels, hidden_channels, kernel_size,
    enc_dilation_rate, n_layers, p_dropout)`` while the body used undefined
    names ``channels`` / ``dilation_rate`` (NameError on construction) and
    the caller in ``ResidualCouplingBlock`` passes exactly five positional
    arguments — the signature is corrected to match both.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        # Project the conditioning half (channels // 2) up to hidden width.
        self.pre = nn.Conv1d(channels // 2, hidden_channels, 1)
        self.convs = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            # "same" padding for an odd kernel so sequence length is kept.
            self.convs.append(
                nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=(kernel_size-1)//2 * dilation, dilation=dilation)
            )
        # Emits mean and log-scale stacked along the channel dimension.
        self.post = nn.Conv1d(hidden_channels, channels, 1)

    def forward(self, x, reverse=False):
        """Apply the coupling transform (or its inverse when ``reverse``).

        Args:
            x: tensor of shape (batch, channels, time).
            reverse: if True, undo the transform instead of applying it.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x0, x1 = torch.chunk(x, 2, dim=1)
        h = self.pre(x0)
        for conv in self.convs:
            h = F.relu(conv(h))
        h = self.post(h)
        m, logs = torch.chunk(h, 2, dim=1)
        if not reverse:
            # Forward: x1 -> m + x1 * exp(logs)
            x1 = m + x1 * torch.exp(logs)
        else:
            # Inverse: exactly undoes the forward affine transform.
            x1 = (x1 - m) * torch.exp(-logs)
        return torch.cat([x0, x1], dim=1)
class ResidualCouplingBlock(nn.Module):
    """A stack of ``n_flows`` (coupling layer, channel flip) pairs.

    The flip between coupling layers swaps which half of the channels is
    transformed, so successive flows act on alternating halves.
    """

    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows):
        super().__init__()
        steps = []
        for _ in range(n_flows):
            steps.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers))
            steps.append(Flip())
        self.flows = nn.ModuleList(steps)

    def forward(self, x, reverse=False):
        """Run all flows in order, or in reverse order with each inverted."""
        sequence = reversed(self.flows) if reverse else self.flows
        for flow in sequence:
            x = flow(x, reverse=reverse)
        return x
class Flip(nn.Module):
    """Reverses the channel order (dim 1) of its input.

    Reversal is its own inverse, so the ``reverse`` flag is accepted for
    flow-interface compatibility but ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, reverse=False):
        return torch.flip(x, dims=[1])
class PosteriorEncoder(nn.Module):
    """Encodes an input spectrogram into a latent sample via a diagonal
    Gaussian: a 1x1 input projection, a stack of dilated Conv1d + ReLU
    layers, and two 1x1 heads for the mean and log-scale, sampled with the
    reparameterization trick.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers):
        super().__init__()
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        # Exponentially dilated, "same"-padded convolution stack.
        self.convs = nn.ModuleList(
            nn.Conv1d(
                hidden_channels,
                hidden_channels,
                kernel_size,
                padding=(kernel_size - 1) // 2 * dilation_rate ** layer_idx,
                dilation=dilation_rate ** layer_idx,
            )
            for layer_idx in range(n_layers)
        )
        self.proj_mean = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj_logvar = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_lengths):
        """Return (z, m, logs) for input ``x`` of shape (batch, in_channels, time).

        NOTE(review): ``x_lengths`` is accepted but not used for masking in
        this implementation — confirm whether masking is handled upstream.
        """
        h = self.pre(x)
        for conv in self.convs:
            h = F.relu(conv(h))
        m = self.proj_mean(h)
        logs = self.proj_logvar(h)
        # Reparameterization: z = m + eps * exp(logs), eps ~ N(0, I).
        z = torch.randn_like(m) * torch.exp(logs) + m
        return z, m, logs

    def infer(self, z, z_lengths):
        # Identity pass-through at inference time.
        return z