Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
class BasicDiscriminatorBlock(nn.Module):
    """Plain (non-residual) stack of four weight-normalised 1-D convolutions.

    The first convolution maps ``in_channel`` to ``out_channel`` with stride 2;
    the three that follow keep ``out_channel`` channels at stride 1. Every
    convolution uses ``kernel_size=3`` and ``padding=1``, and an in-place
    ``LeakyReLU(0.2)`` sits after each convolution except the last one.

    Args:
        in_channel: Number of channels of the incoming tensor.
        out_channel: Number of channels produced by every convolution.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        def wn_conv(src_channels, stride):
            # Shared recipe: 3-tap convolution, padding 1, wrapped in weight norm.
            conv = nn.Conv1d(
                src_channels,
                out_channel,
                kernel_size=3,
                stride=stride,
                padding=1,
            )
            return nn.utils.weight_norm(conv)

        # Strided entry convolution first, then three (activation, conv) pairs,
        # so the sequence ends on a convolution without a trailing activation.
        layers = [wn_conv(in_channel, 2)]
        for _ in range(3):
            layers.append(nn.LeakyReLU(0.2, True))
            layers.append(wn_conv(out_channel, 1))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.block(x)
class ResDiscriminatorBlock(nn.Module):
    """Two consecutive residual units built from weight-normalised 1-D convolutions.

    Each unit is a pair of 3-tap convolutions (with an in-place
    ``LeakyReLU(0.2)`` between them) whose output is added to a 1-tap
    convolutional shortcut of the same input. The first unit changes the
    channel count from ``in_channel`` to ``out_channel`` and uses stride 2 on
    both its first convolution and its shortcut; the second unit keeps
    ``out_channel`` channels at stride 1.

    Args:
        in_channel: Number of channels of the incoming tensor.
        out_channel: Number of channels produced by the block.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        def wn_conv(src_channels, kernel_size, stride, padding=0):
            # Every convolution in this block outputs ``out_channel`` channels.
            conv = nn.Conv1d(
                src_channels,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
            return nn.utils.weight_norm(conv)

        def conv_pair(src_channels, first_stride):
            # Main path of one residual unit: conv -> LeakyReLU -> conv.
            return nn.Sequential(
                wn_conv(src_channels, 3, first_stride, 1),
                nn.LeakyReLU(0.2, True),
                wn_conv(out_channel, 3, 1, 1),
            )

        # First unit: both paths use stride 2.
        self.block1 = conv_pair(in_channel, 2)
        self.shortcut1 = wn_conv(in_channel, 1, 2)

        # Second unit: stride 1 on both paths.
        self.block2 = conv_pair(out_channel, 1)
        self.shortcut2 = wn_conv(out_channel, 1, 1)

    def forward(self, x):
        """Apply both residual units to ``x`` and return the final sum."""
        main_path = self.block1(x)
        hidden = main_path + self.shortcut1(x)
        return self.block2(hidden) + self.shortcut2(hidden)
class ResNet18Discriminator(nn.Module):
    """Convolutional discriminator trunk: a stem followed by four stages.

    The stem is a weight-normalised ``Conv1d`` (kernel 7, stride 2, padding 1)
    with an in-place ``LeakyReLU(0.2)``. It is followed by one
    ``BasicDiscriminatorBlock`` and three ``ResDiscriminatorBlock`` stages
    whose output widths are ``in_channel * 2``, ``* 4`` and ``* 8``.

    Args:
        stft_channel: Number of channels of the tensor fed to the stem.
        in_channel: Base channel width produced by the stem (default 64).
    """

    def __init__(self, stft_channel, in_channel=64):
        super().__init__()
        width = in_channel

        stem_conv = nn.Conv1d(stft_channel, width, kernel_size=7, stride=2, padding=1)
        self.input = nn.Sequential(
            nn.utils.weight_norm(stem_conv),
            nn.LeakyReLU(0.2, True),
        )

        # The channel width doubles at each residual stage.
        self.df1 = BasicDiscriminatorBlock(width, width)
        self.df2 = ResDiscriminatorBlock(width, width * 2)
        self.df3 = ResDiscriminatorBlock(width * 2, width * 4)
        self.df4 = ResDiscriminatorBlock(width * 4, width * 8)

    def forward(self, x):
        """Feed ``x`` through the stem and the four stages in order."""
        for stage in (self.input, self.df1, self.df2, self.df3, self.df4):
            x = stage(x)
        return x
class FrequencyDiscriminator(nn.Module):
    """Discriminator that operates on the STFT of a signal.

    The input is transformed with ``torch.stft``; the real and imaginary parts
    of the spectrogram are then each passed through the same
    ``ResNet18Discriminator`` instance (shared weights).

    Args:
        in_channel: Base channel width handed to ``ResNet18Discriminator``.
        fft_size: FFT size used by the STFT; the spectrogram has
            ``fft_size // 2 + 1`` frequency bins.
        hop_length: Hop between successive STFT frames.
        win_length: Length of the analysis window.
        window: Name of a ``torch`` window factory (e.g. ``"hann_window"``)
            that is called with ``win_length``.
    """

    def __init__(self, in_channel=64, fft_size=1024, hop_length=256, win_length=1024, window="hann_window"):
        super(FrequencyDiscriminator, self).__init__()
        self.fft_size = fft_size
        self.hop_length = hop_length
        self.win_length = win_length
        # Stored as a frozen parameter so it follows the module across devices
        # and is saved in the state dict without ever being trained.
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.stft_channel = fft_size // 2 + 1
        self.resnet_disc = ResNet18Discriminator(self.stft_channel, in_channel)

    def forward(self, x):
        """Return the discriminator outputs for the real and imaginary STFT parts.

        Args:
            x: Real-valued signal accepted by ``torch.stft``
               (presumably shaped ``(batch, samples)`` -- TODO confirm with callers).

        Returns:
            Tuple ``(x_real, x_imag)`` with the shared network's output for
            the real part and for the imaginary part of the spectrogram.
        """
        # ``return_complex`` must be given explicitly for real inputs on
        # current PyTorch; omitting it raises a RuntimeError, and the old
        # real-valued ``(..., 2)`` output layout is no longer available.
        x_stft = torch.stft(
            x,
            self.fft_size,
            self.hop_length,
            self.win_length,
            self.window,
            return_complex=True,
        )
        # Same values the former ``[..., 0]`` / ``[..., 1]`` indexing selected.
        real = x_stft.real
        imag = x_stft.imag
        x_real = self.resnet_disc(real)
        x_imag = self.resnet_disc(imag)
        return x_real, x_imag