import math
import random
from typing import Optional, Any

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F

random.seed(0)


def _get_activation_fn(activ):
    if activ == 'relu':
        return nn.ReLU()
    elif activ == 'lrelu':
        return nn.LeakyReLU(0.2)
    elif activ == 'swish':
        # nn.SiLU computes x * sigmoid(x); returning a module (rather than a bare
        # lambda) lets the result be used inside nn.Sequential below.
        return nn.SiLU()
    else:
        raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class CausualConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1,
                 dilation=1, bias=True, w_init_gain='linear', param=None):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        # Pad both sides, then trim the right in forward() so the convolution stays causal.
        self.padding = padding * 2
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=self.padding,
                              dilation=dilation,
                              bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        x = self.conv(x)
        if self.padding > 0:
            # Drop the trailing frames so outputs never depend on future inputs.
            x = x[:, :, :-self.padding]
        return x


class CausualBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3 ** i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
        layers = [
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)


class ConvBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
        super().__init__()
        # hidden_dim must be divisible by this group count for the GroupNorm layers below.
        self._n_groups = 8
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3 ** i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
        layers = [
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and previous attention weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
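
# A hedged usage note (an assumption, not taken from this file): in Tacotron2-style
# decoders the `mask` given to Attention.forward is usually True at padded memory
# positions, for example
#
#     lengths = torch.tensor([50, 32])                              # valid frames per item
#     mask = torch.arange(max_time)[None, :] >= lengths[:, None]    # (B, max_time), True = padded
#
# so that masked_fill_ pushes those scores to -inf before the softmax.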


class ForwardAttentionV2(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(ForwardAttentionV2, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: previous and cumulative attention weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask, log_alpha):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        log_alpha: accumulated forward-attention log weights from the previous step (B, max_time)
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        # Forward-attention recursion in log space: each position either stays (shift 0)
        # or advances by one (shift 1); the two hypotheses are combined with logsumexp
        # before adding the new content-based energies.
        log_alpha_shift_padded = []
        max_time = log_energy.size(1)
        for sft in range(2):
            shifted = log_alpha[:, :max_time - sft]
            shift_padded = F.pad(shifted, (sft, 0), 'constant', self.score_mask_value)
            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))

        biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)

        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)

        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights, log_alpha_new
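
# A hedged usage sketch (an assumption, not taken from this file): callers typically
# carry log_alpha across decoder steps, initialising it so essentially all probability
# mass starts at the first memory position and feeding the returned log_alpha_new back in:
#
#     log_alpha = torch.full((batch_size, max_time), -1e20)
#     log_alpha[:, 0] = 0.0
#     ...
#     context, weights, log_alpha = forward_attn(h, memory, processed_memory,
#                                                weights_cat, mask, log_alpha)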


class PhaseShuffle2d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle2d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x: (B, C, H, W); circularly shift the last axis by a random offset in [-n, n].
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :, :move]
            right = x[:, :, :, move:]
            shuffled = torch.cat([right, left], dim=3)
            return shuffled


class PhaseShuffle1d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle1d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x: (B, C, T); circularly shift the time axis by a random offset in [-n, n].
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :move]
            right = x[:, :, move:]
            shuffled = torch.cat([right, left], dim=2)
            return shuffled


class MFCC(nn.Module):
    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = 'ortho'
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)  # (n_mels, n_mfcc)
        self.register_buffer('dct_mat', dct_mat)

    def forward(self, mel_specgram):
        if len(mel_specgram.shape) == 2:
            mel_specgram = mel_specgram.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False

        # (B, T, n_mels) @ (n_mels, n_mfcc) -> (B, T, n_mfcc) -> (B, n_mfcc, T)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        if unsqueezed:
            mfcc = mfcc.squeeze(0)
        return mfcc
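

# A minimal smoke test (not part of the original file), assuming mel inputs of shape
# (batch=2, n_mels=80, frames=50). It only checks the tensor shapes of a few modules above.
if __name__ == "__main__":
    mel = torch.randn(2, 80, 50)

    # MFCC: (B, n_mels, T) -> (B, n_mfcc, T)
    assert MFCC(n_mfcc=40, n_mels=80)(mel).shape == (2, 40, 50)

    # ConvBlock keeps the (B, C, T) layout; hidden_dim must be divisible by 8 (GroupNorm).
    assert ConvBlock(hidden_dim=80)(mel).shape == mel.shape

    # PhaseShuffle1d rotates along the time axis without changing the shape.
    assert PhaseShuffle1d(n=2)(mel).shape == mel.shape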