| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| |
|
import torch.nn as nn

from sparktts.modules.blocks.layers import (
    Snake1d,
    WNConv1d,
    ResidualUnit,
    WNConvTranspose1d,
    init_weights,
)
| |
|
| |
|
class DecoderBlock(nn.Module):
    """One upsampling stage of the waveform decoder.

    Applies a Snake activation, a weight-normalized transposed convolution
    that upsamples the time axis by ``stride``, then three residual units
    with dilations 1, 3 and 9.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels after the transposed conv.
        kernel_size: transposed-conv kernel size.
        stride: temporal upsampling factor.
    """

    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        kernel_size: int = 2,
        stride: int = 1,
    ):
        super().__init__()
        # padding=(kernel_size - stride) // 2 keeps the output length at
        # exactly stride * input_length for the usual kernel/stride pairs.
        modules = [
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=(kernel_size - stride) // 2,
            ),
        ]
        modules.extend(ResidualUnit(output_dim, dilation=d) for d in (1, 3, 9))
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        """Run the stage; assumes x is (batch, input_dim, time) — the
        standard Conv1d layout (confirm against the project layers)."""
        return self.block(x)
| |
|
| |
|
class WaveGenerator(nn.Module):
    """Stack of upsampling decoder blocks that turns latent features into a
    waveform.

    Structure: an initial width-7 conv projecting ``input_channel`` to
    ``channels``, one :class:`DecoderBlock` per (kernel_size, rate) pair —
    each halving the channel count while upsampling time by ``rate`` — and a
    final Snake1d + width-7 conv to ``d_out`` channels squashed to [-1, 1]
    by ``tanh``.

    Args:
        input_channel: number of channels of the input feature sequence.
        channels: channel width after the initial projection.
        rates: per-stage temporal upsampling factors.
        kernel_sizes: per-stage transposed-conv kernel sizes; must be the
            same length as ``rates``.
        d_out: number of output waveform channels (default mono).

    Raises:
        ValueError: if ``kernel_sizes`` and ``rates`` differ in length
            (previously ``zip`` silently truncated the longer list).
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        kernel_sizes,
        d_out: int = 1,
    ):
        super().__init__()

        # zip() would silently drop trailing entries on a length mismatch,
        # producing a mis-configured network — fail loudly instead.
        if len(kernel_sizes) != len(rates):
            raise ValueError(
                f"kernel_sizes (len {len(kernel_sizes)}) and rates "
                f"(len {len(rates)}) must have the same length"
            )

        # Initial projection: input_channel -> channels, length preserved.
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Each stage halves the channel count while upsampling time by `stride`.
        # Pre-initialize output_dim so an empty `rates` yields a conv-only
        # model instead of a NameError below.
        output_dim = channels
        for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)]

        # Final projection to d_out waveform channels in [-1, 1].
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

        self.apply(init_weights)

    def forward(self, x):
        """Generate a waveform; assumes x is (batch, input_channel, time) —
        standard Conv1d layout (TODO confirm against callers)."""
        return self.model(x)
| |
|