| | import math |
| | import time |
| | from collections import OrderedDict |
| | from typing import Dict, List, Optional, Tuple |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from espnet2.torch_utils.get_layer_from_string import get_layer |
| | from torch.nn import init |
| | from torch.nn.parameter import Parameter |
| | import src.utils as utils |
| |
|
| |
|
class Lambda(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module``.

    Useful for inserting small elementwise/reshaping operations into
    ``nn.Sequential`` pipelines.

    Args:
        lambd: callable applied to the input in ``forward``.

    Raises:
        TypeError: if ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, so they must not carry input validation. Accepting
        # any callable (not only LambdaType) is a backward-compatible
        # generalization -- lambdas and plain functions still pass.
        if not callable(lambd):
            raise TypeError(f"lambd must be callable, got {type(lambd).__name__}")
        self.lambd = lambd

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        return self.lambd(x)
| |
|
| |
|
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm over the channel axis of a channels-first 4-D tensor.

    ``nn.LayerNorm`` normalizes trailing dimensions, so the input is
    temporarily rearranged to channels-last, normalized, and rearranged
    back to channels-first.
    """

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]
        Returns:
            Tensor of the same shape, layer-normalized over C.
        """
        channels_last = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        normalized = super().forward(channels_last)
        return normalized.permute(0, 3, 1, 2)  # back to [B, C, T, F]
| |
|
| |
|
| | |
class LayerNormalization4D(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` over the last dimension.

    Args:
        C: size of the trailing (channel) dimension to normalize.
        eps: numerical-stability epsilon forwarded to ``nn.LayerNorm``.
        preserve_outdim: stored on the instance; not consumed by this
            module's own forward pass.
    """

    def __init__(self, C, eps=1e-5, preserve_outdim=False):
        super().__init__()
        self.preserve_outdim = preserve_outdim
        self.norm = nn.LayerNorm(C, eps=eps)

    def forward(self, x: torch.Tensor):
        """Normalize over the last axis; input shape (*, C)."""
        return self.norm(x)
| |
|
| |
|
class LayerNormalization4DCF(nn.Module):
    """LayerNorm over a flattened (frequency * channel) trailing axis.

    Args:
        input_dimension: pair ``(Q, C)`` -- frequency and channel sizes
            whose product is the normalized dimension.
        eps: numerical-stability epsilon for ``nn.LayerNorm``.
    """

    def __init__(self, input_dimension, eps=1e-5):
        assert len(input_dimension) == 2
        super().__init__()
        num_freqs, num_channels = input_dimension
        self.norm = nn.LayerNorm(num_freqs * num_channels, eps=eps)

    def forward(self, x: torch.Tensor):
        """Normalize over the last axis; input shape (B, T, Q * C)."""
        return self.norm(x)
| |
|
| |
|
class LayerNormalization4D_old(nn.Module):
    """Channel-wise layer normalization for 4-D tensors [B, C, T, F].

    Normalizes over the channel axis (dim=1) and applies learnable
    per-channel affine parameters gamma (scale, init 1) and beta
    (shift, init 0).

    Args:
        input_dimension: number of channels C.
        eps: numerical-stability epsilon added to the variance.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        # torch.ones/zeros instead of uninitialized torch.Tensor(...) +
        # init.ones_/zeros_: same final values, but never allocates (or
        # briefly exposes) uninitialized memory.
        self.gamma = Parameter(torch.ones(param_size, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(param_size, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over the channel axis.

        Args:
            x: tensor of shape [B, C, T, F].
        Returns:
            Tensor of the same shape.
        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.ndim == 4:
            _, C, _, _ = x.shape
            stat_dim = (1,)
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mu_ = x.mean(dim=stat_dim, keepdim=True)
        # Biased variance (unbiased=False) matches standard LayerNorm.
        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat
| |
|
| |
|
def mod_pad(x, chunk_size, pad):
    """Right-pad the last axis of ``x`` up to a multiple of ``chunk_size``,
    then apply the additional ``F.pad`` spec ``pad``.

    Args:
        x: input tensor.
        chunk_size: alignment target for the last dimension.
        pad: extra padding spec forwarded to ``F.pad`` after alignment.

    Returns:
        Tuple of (padded tensor, number of zeros appended for alignment).
    """
    # (-n) % chunk_size is 0 when n is already a multiple of chunk_size,
    # otherwise the shortfall to the next multiple -- identical to the
    # conditional "chunk_size - (n % chunk_size)" computation.
    mod = (-x.shape[-1]) % chunk_size
    padded = F.pad(F.pad(x, (0, mod)), pad)
    return padded, mod
| |
|
| |
|
class Attention_STFT_causal(nn.Module):
    """Causal multi-head self-attention over the time axis of a
    (batch, time, freq, channel) representation.

    Per-head Q/K/V projections fold the frequency axis into the feature
    dimension, so attention scores compare whole time frames. Causality
    is enforced with a lower-triangular mask; when ``local_context_len``
    is set (!= -1) the mask is additionally banded so each frame only
    attends to the most recent ``local_context_len`` frames.
    """

    def __getitem__(self, key):
        # Enables self["attn_conv_Q"]-style lookup of submodules that
        # were registered via add_module below.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        n_freqs,
        approx_qk_dim=512,
        n_head=4,
        activation="prelu",
        eps=1e-5,
        skip_conn=True,
        use_flash_attention=False,
        dim_feedforward=-1,
        local_context_len=-1,

    ):
        super().__init__()
        # Sinusoidal positional code over the flattened (freq * channel)
        # frame vector; capped at 5000 time frames.
        self.position_code = utils.PositionalEncoding(emb_dim * n_freqs, max_len=5000)

        self.skip_conn = skip_conn
        self.n_freqs = n_freqs
        # E: per-frequency Q/K channel width, chosen so that the flattened
        # per-head key dimension (n_freqs * E) is roughly approx_qk_dim.
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        self.n_head = n_head
        self.V_dim = emb_dim // n_head
        self.emb_dim = emb_dim
        assert emb_dim % n_head == 0
        E = self.E

        # NOTE(review): use_flash_attention is stored but never read in
        # this class's visible code -- attention below is always the
        # explicit matmul/softmax path.
        self.use_flash_attention = use_flash_attention

        self.local_context_len = local_context_len

        # Q projection: Linear over channels, activation, then fold heads
        # out of the batch axis and flatten frequency into the feature
        # axis: (B, T, Q, C) -> (B*n_head, T, Q*E), followed by LayerNorm.
        self.add_module(
            "attn_conv_Q",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),

                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        # K projection: same layout transform as Q.
        self.add_module(
            "attn_conv_K",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        # V projection: per-head width emb_dim // n_head instead of E.
        self.add_module(
            "attn_conv_V",
            nn.Sequential(
                nn.Linear(emb_dim, (emb_dim // n_head) * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, (emb_dim // n_head))
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * (emb_dim // n_head))
                ),
                LayerNormalization4DCF((n_freqs, emb_dim // n_head), eps=eps),
            ),
        )

        self.dim_feedforward = dim_feedforward

        # Output stage: either a conv-style concat projection
        # (dim_feedforward == -1) or a transformer-style feed-forward
        # block with dropout + LayerNorm.
        if dim_feedforward == -1:
            self.add_module(
                "attn_concat_proj",
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    get_layer(activation)(),
                    Lambda(lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])),
                    LayerNormalization4DCF((n_freqs, emb_dim), eps=eps),
                ),
            )
        else:
            self.linear1 = nn.Linear(emb_dim, dim_feedforward)
            self.dropout = nn.Dropout(p=0.1)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(dim_feedforward, emb_dim)
            self.dropout2 = nn.Dropout(p=0.1)
            self.norm = LayerNormalization4DCF((n_freqs, emb_dim), eps=eps)

    def _ff_block(self, x):
        # Standard transformer feed-forward sub-block:
        # linear -> ReLU -> dropout -> linear -> dropout.
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def get_lookahead_mask(self, seq_len, device):
        """Build a boolean (seq_len, seq_len) causal attention mask.

        mask[i, j] is True when query frame i may attend to key frame j,
        i.e. j <= i; when local_context_len != -1 it additionally
        requires i - j < local_context_len (banded causal mask).
        """
        if self.local_context_len == -1:
            # Unlimited lookback: plain lower-triangular mask.
            mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)

            return mask.detach().to(device)

        else:
            # mask1: upper triangle incl. diagonal; mask2: entries below
            # the local_context_len-th superdiagonal. Their product,
            # transposed, keeps j <= i < j + local_context_len.
            mask1 = torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1
            mask2 = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=self.local_context_len) == 0
            mask = (mask1 * mask2).transpose(0, 1)

            return mask.detach().to(device)

    def forward(self, batch):
        """Apply causal self-attention along time.

        Args:
            batch: tensor of shape (B, T, Q, C) -- batch, time frames,
                frequency bins, channels.
        Returns:
            Tensor with the skip connection added when skip_conn is True.
        """
        inputs = batch
        B0, T0, Q0, C0 = batch.shape

        # Positional encoding comes back as (B', T, Q*C); it is reshaped
        # and broadcast-added over the batch axis.
        pos_code = self.position_code(batch)
        _, T, QC = pos_code.shape
        pos_code = pos_code.reshape(1, T, Q0, C0)
        batch = batch + pos_code

        # Each projection yields (B*n_head, T, n_freqs * head_dim).
        Q = self["attn_conv_Q"](batch)
        K = self["attn_conv_K"](batch)
        V = self["attn_conv_V"](batch)

        emb_dim = Q.shape[-1]

        local_mask = self.get_lookahead_mask(batch.shape[1], batch.device)

        # Scaled dot-product attention over time, masked so disallowed
        # positions get -inf before the softmax.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)
        attn_mat.masked_fill_(local_mask == 0, -float("Inf"))
        attn_mat = F.softmax(attn_mat, dim=2)

        V = torch.matmul(attn_mat, V)
        V = V.reshape(-1, T0, V.shape[-1])
        V = V.transpose(1, 2)

        # Undo the head folding: (B*H, F*Vd, T) -> (B, T, Q, C).
        batch = V.reshape(B0, self.n_head, self.n_freqs, self.V_dim, T0)
        batch = batch.transpose(2, 3)
        batch = batch.reshape(B0, self.n_head * self.V_dim, self.n_freqs, T0)
        batch = batch.permute(0, 3, 2, 1)

        if self.dim_feedforward == -1:
            # NOTE(review): attn_concat_proj's Lambda flattens to
            # (B, T, Q*C) and the trailing LayerNorm keeps that shape,
            # while the else-branch reshapes back to (B, T, Q, C);
            # confirm the skip connection below is shape-compatible on
            # this path.
            batch = self["attn_concat_proj"](batch)
        else:
            batch = batch + self._ff_block(batch)
            batch = batch.reshape(batch.shape[0], batch.shape[1], batch.shape[2] * batch.shape[3])
            batch = self.norm(batch)
            batch = batch.reshape(batch.shape[0], batch.shape[1], Q0, C0)

        # Residual connection around the whole attention block.
        if self.skip_conn:
            return batch + inputs
        else:
            return batch
| |
|
| |
|
class GridNetBlock(nn.Module):
    """One GridNet-style dual-path block.

    Processing order in ``forward``:
      1. intra-frequency bidirectional LSTM (along the frequency axis Q),
      2. inter-time bidirectional LSTM run on folded chunks of the time
         axis (chunk length ``lstm_fold_chunk``),
      3. pooling over each chunk followed by causal attention across
         chunks (``Attention_STFT_causal``).

    Residual connections wrap each stage. ``init_state`` is passed
    through ``forward`` unchanged.
    """

    def __getitem__(self, key):
        # Enables self["name"]-style lookup of registered submodules.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        lstm_fold_chunk,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
        pool="mean",
        last=False,
        local_context_len=-1,

    ):
        """
        Args:
            emb_dim: embedding (channel) dimension C.
            emb_ks: unfolding kernel size for the LSTM input framing.
            emb_hs: unfolding hop size.
            n_freqs: number of frequency bins Q.
            hidden_channels: LSTM hidden size.
            lstm_fold_chunk: time-axis chunk length for the inter LSTM.
            n_head: attention heads for the chunk-level attention.
            approx_qk_dim: target flattened Q/K dimension for attention.
            activation: activation layer name resolved via get_layer.
            eps: epsilon for the normalization layers.
            pool: "mean" or "max" pooling over each time chunk.
            last: if True, forward returns the pooled+attended chunks
                directly instead of the full-resolution residual output.
            local_context_len: local attention window (-1 = unlimited).
        """
        super().__init__()
        bidirectional = True

        # NOTE(review): stored but not read in this class's visible code.
        self.global_atten_causal = True

        self.last = last

        self.pool = pool

        self.lstm_fold_chunk = lstm_fold_chunk
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)

        self.V_dim = emb_dim // n_head
        self.H = hidden_channels
        in_channels = emb_dim * emb_ks
        self.in_channels = in_channels
        self.n_freqs = n_freqs

        # Intra (frequency-axis) path.
        self.intra_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
        self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs

        # Inter (time-axis, chunked) path.
        self.inter_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=bidirectional)
        self.inter_linear = nn.ConvTranspose1d(hidden_channels * (bidirectional + 1), emb_dim, emb_ks, stride=emb_hs)

        # Causal attention across chunk summaries.
        self.pool_atten_causal = Attention_STFT_causal(
            emb_dim=emb_dim,
            n_freqs=n_freqs,
            approx_qk_dim=approx_qk_dim,
            n_head=n_head,
            activation=activation,
            eps=eps,
            local_context_len=local_context_len,
        )

    def _unfold_timedomain(self, x):
        """Fold the time axis into fixed-size chunks.

        Args:
            x: tensor of shape [B*Q, C, T].
        Returns:
            Tensor of shape [B*Q, num_chunks, lstm_fold_chunk, C].
        Assumes T is divisible by lstm_fold_chunk -- a ragged final
        chunk would make the cat/reshape below fail (TODO confirm the
        caller guarantees this).
        """
        BQ, C, T = x.shape
        x = torch.split(x, self.lstm_fold_chunk, dim=-1)
        x = torch.cat(x, dim=0).reshape(-1, BQ, C, self.lstm_fold_chunk)
        x = x.permute(1, 0, 3, 2)
        return x

    def forward(self, x, init_state=None):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]
            init_state: opaque state, returned unchanged.
        Returns:
            (out, init_state) where out is [B, C, T, Q] unless
            self.last is True (then the pooled, attended chunk tensor).
        """
        B, C, old_T, old_Q = x.shape
        # Pad T and Q up so the emb_ks/emb_hs unfolding tiles exactly.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # ---- Intra path: BLSTM along the frequency axis ----
        input_ = x
        intra_rnn = self.intra_norm(input_)
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)

        # Frame the frequency axis into emb_ks-wide windows, flatten them
        # into the feature dimension for the LSTM.
        intra_rnn = torch.split(intra_rnn, self.emb_ks, dim=-1)
        intra_rnn = torch.stack(intra_rnn, dim=0)
        intra_rnn = intra_rnn.permute(1, 2, 3, 0).flatten(1, 2)
        intra_rnn = intra_rnn.transpose(1, 2)
        self.intra_rnn.flatten_parameters()

        intra_rnn, _ = self.intra_rnn(intra_rnn)
        intra_rnn = intra_rnn.transpose(1, 2)
        # ConvTranspose1d undoes the emb_ks/emb_hs framing.
        intra_rnn = self.intra_linear(intra_rnn)
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()
        intra_rnn = intra_rnn + input_  # residual
        intra_rnn = intra_rnn[:, :, :, :old_Q]  # drop frequency padding
        Q = old_Q

        # ---- Inter path: BLSTM along time, within folded chunks ----
        inter_rnn = self.inter_norm(intra_rnn)
        inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)

        inter_rnn = self._unfold_timedomain(inter_rnn)

        BQ, NUM_CHUNK, CHUNKSIZE, C = inter_rnn.shape

        inter_rnn = inter_rnn.reshape(BQ * NUM_CHUNK, CHUNKSIZE, C)
        inter_rnn = inter_rnn.transpose(2, 1)
        input_ = inter_rnn

        # Same emb_ks framing as the intra path, now along chunk time.
        inter_rnn = torch.split(inter_rnn, self.emb_ks, dim=-1)

        inter_rnn = torch.stack(inter_rnn, dim=0)
        inter_rnn = inter_rnn.permute(1, 2, 3, 0)

        BF, C, EO, _T = inter_rnn.shape
        inter_rnn = inter_rnn.reshape(BF, C * EO, _T)

        inter_rnn = inter_rnn.transpose(1, 2)

        self.inter_rnn.flatten_parameters()
        inter_rnn, _ = self.inter_rnn(inter_rnn)
        inter_rnn = inter_rnn.transpose(1, 2)
        inter_rnn = self.inter_linear(inter_rnn)
        inter_rnn = inter_rnn + input_  # residual within each chunk

        inter_rnn = inter_rnn.reshape(B, Q, NUM_CHUNK, C, CHUNKSIZE)
        inter_rnn = inter_rnn.permute(0, 1, 2, 4, 3)

        # ---- Chunk summaries + causal attention across chunks ----
        input_ = inter_rnn
        if self.pool == "mean":
            inter_rnn = torch.mean(inter_rnn, dim=3)
        elif self.pool == "max":
            inter_rnn, _ = torch.max(inter_rnn, dim=3)
        else:
            # Fixed typo in the error message ("INvalid" -> "Invalid").
            raise ValueError("Invalid pool type!")

        inter_rnn = inter_rnn.transpose(1, 2)
        inter_rnn = self.pool_atten_causal(inter_rnn)
        inter_rnn = inter_rnn.transpose(1, 2)

        # `if self.last:` instead of `== True` (idiomatic truth test).
        if self.last:
            return inter_rnn, init_state

        else:
            # Broadcast the attended chunk summary back over each chunk's
            # time steps, then restore the [B, C, T, Q] layout.
            inter_rnn = inter_rnn.unsqueeze(3)
            inter_rnn = input_ + inter_rnn

            inter_rnn = inter_rnn.reshape(B, Q, T, C)
            inter_rnn = inter_rnn.permute(0, 3, 2, 1)
            inter_rnn = inter_rnn[..., :old_T, :]  # drop time padding

            return inter_rnn, init_state
| |
|