import torch
from torch import nn
from torch.nn import functional as F

from . import attentions

class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x)) (https://arxiv.org/abs/1908.08681)."""

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
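
# Note: newer versions of PyTorch (>= 1.9) ship nn.Mish built in; the local
# definition above keeps this module self-contained and compatible with
# older PyTorch releases.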

class Conv1dGLU(nn.Module):
    """
    Conv1d + GLU (Gated Linear Unit) with a residual connection.
    For GLU, see https://arxiv.org/abs/1612.08083.
    """
    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super().__init__()
        self.out_channels = out_channels
        self.conv1 = nn.Conv1d(
            in_channels,
            2 * out_channels,
            kernel_size=kernel_size,
            # "same" padding for odd kernel sizes; the original hardcoded
            # padding=2, which is only correct for kernel_size=5
            padding=kernel_size // 2,
        )
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        residual = x
        x = self.conv1(x)
        # split the doubled channels into content (x1) and gate (x2) halves
        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x
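
# A minimal shape check for Conv1dGLU (illustrative only; the batch size,
# channel count, and sequence length below are arbitrary assumptions):
#
#     m = Conv1dGLU(in_channels=128, out_channels=128, kernel_size=5, dropout=0.1)
#     y = m(torch.randn(2, 128, 100))   # -> torch.Size([2, 128, 100])
#
# The residual connection requires in_channels == out_channels, since
# `residual + self.dropout(x)` adds the input and the gated output elementwise.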

class StyleEncoder(torch.nn.Module):
    def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
        super().__init__()

        self.in_dim = in_dim          # e.g. 513 = n_fft // 2 + 1 spectrogram bins
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.kernel_size = 5
        self.n_head = 2
        self.dropout = 0.1
        # point-wise spectral transform: two 1x1 convolutions with Mish
        self.spectral = nn.Sequential(
            nn.Conv1d(self.in_dim, self.hidden_dim, 1),
            Mish(),
            nn.Dropout(self.dropout),
            nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
            Mish(),
            nn.Dropout(self.dropout),
        )
        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )
        self.slf_attn = attentions.MultiHeadAttention(
            self.hidden_dim,
            self.hidden_dim,
            self.n_head,
            p_dropout=self.dropout,
            proximal_bias=False,
            proximal_init=True,
        )
        self.atten_drop = nn.Dropout(self.dropout)
        self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)
    def forward(self, x, mask=None):
        # x: [B, in_dim, T]; mask: [B, 1, T] with 1.0 at valid frames.
        # The original multiplied by `mask` unconditionally, which crashes
        # when mask is None; fall back to an all-ones mask instead.
        if mask is None:
            mask = torch.ones(x.size(0), 1, x.size(2), device=x.device, dtype=x.dtype)

        # spectral
        x = self.spectral(x) * mask
        # temporal
        x = self.temporal(x) * mask

        # self-attention with a residual connection
        attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)  # [B, 1, T, T]
        y = self.slf_attn(x, x, attn_mask=attn_mask)
        x = x + self.atten_drop(y)

        # fc
        x = self.fc(x)

        # temporal average pooling -> [B, out_dim]
        w = self.temporal_avg_pool(x, mask=mask)

        return w
    def temporal_avg_pool(self, x, mask=None):
        if mask is None:
            out = torch.mean(x, dim=2)
        else:
            len_ = mask.sum(dim=2)       # [B, 1]: number of valid frames
            # zero out padded frames before summing; the original summed
            # unmasked frames, which lets padding leak into the average
            x = (x * mask).sum(dim=2)
            out = torch.div(x, len_)
        return out
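
# A minimal smoke test (a sketch, not part of the original module). It assumes
# `attentions.MultiHeadAttention` follows the VITS-style interface used above,
# i.e. forward(x, c, attn_mask=...) over channels-first [B, C, T] tensors.
# Because of the relative import, run it as a module inside its package
# (python -m <package>.<module>), not as a standalone script.
if __name__ == "__main__":
    encoder = StyleEncoder(in_dim=513, hidden_dim=128, out_dim=256)
    spec = torch.randn(2, 513, 100)   # [B, in_dim, T] linear spectrogram
    mask = torch.ones(2, 1, 100)      # [B, 1, T], all frames valid
    style = encoder(spec, mask=mask)
    print(style.shape)                # expected: torch.Size([2, 256])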