import math
import time
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from espnet2.torch_utils.get_layer_from_string import get_layer
from torch.nn import init
from torch.nn.parameter import Parameter

import src.utils as utils


class Lambda(nn.Module):
    """Wrap a lambda as an nn.Module so it can be used inside nn.Sequential."""

    def __init__(self, lambd):
        super().__init__()
        import types

        assert type(lambd) is types.LambdaType
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)

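# Illustrative usage (an assumption, not from the original code):
#   net = nn.Sequential(nn.Linear(8, 8), Lambda(lambda x: x * 2))

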
class LayerNormPermuted(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super(LayerNormPermuted, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]
        """
        x = x.permute(0, 2, 3, 1)  # [B, T, F, C]
        x = super().forward(x)  # normalize over the trailing C dim
        x = x.permute(0, 3, 1, 2)  # back to [B, C, T, F]
        return x


class LayerNormalization4D(nn.Module):
    def __init__(self, C, eps=1e-5, preserve_outdim=False):
        super().__init__()
        self.norm = nn.LayerNorm(C, eps=eps)
        self.preserve_outdim = preserve_outdim

    def forward(self, x: torch.Tensor):
        """
        input: (*, C)
        """
        x = self.norm(x)
        return x


class LayerNormalization4DCF(nn.Module):
    def __init__(self, input_dimension, eps=1e-5):
        assert len(input_dimension) == 2
        Q, C = input_dimension
        super().__init__()
        self.norm = nn.LayerNorm(Q * C, eps=eps)

    def forward(self, x: torch.Tensor):
        """
        input: (B, T, Q * C)
        """
        x = self.norm(x)
        return x


class LayerNormalization4D_old(nn.Module):
    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32))
        self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32))
        init.ones_(self.gamma)
        init.zeros_(self.beta)
        self.eps = eps

    def forward(self, x):
        if x.ndim == 4:
            _, C, _, _ = x.shape
            stat_dim = (1,)
        else:
            raise ValueError("Expected x to have 4 dimensions, but got {}".format(x.ndim))
        # Normalize over the channel axis, per (t, q) position.
        mu_ = x.mean(dim=stat_dim, keepdim=True)
        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat

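# Note (our reading of the math above, not a documented equivalence): for a
# [B, C, T, Q] input this matches the nn.LayerNorm(C)-based
# LayerNormalization4D applied channels-last, since both use the biased
# variance, add eps inside the square root, and learn per-channel affine
# parameters.

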
def mod_pad(x, chunk_size, pad):
    """Right-pad the last axis of ``x`` to a multiple of ``chunk_size``, then
    apply the extra padding ``pad``. Returns the padded tensor and the number
    of elements ``mod`` added by the chunk-size rounding."""
    mod = 0
    if (x.shape[-1] % chunk_size) != 0:
        mod = chunk_size - (x.shape[-1] % chunk_size)

    x = F.pad(x, (0, mod))
    x = F.pad(x, pad)

    return x, mod

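# Example (illustrative): for x with x.shape[-1] == 10, chunk_size=4 and
# pad=(0, 0), mod_pad right-pads by mod=2 to length 12; callers can use the
# returned mod to trim the output back to the original length.

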
class Attention_STFT_causal(nn.Module):
    def __getitem__(self, key):
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        n_freqs,
        approx_qk_dim=512,
        n_head=4,
        activation="prelu",
        eps=1e-5,
        skip_conn=True,
        use_flash_attention=False,
        dim_feedforward=-1,
        local_context_len=-1,
    ):
        super().__init__()
        self.position_code = utils.PositionalEncoding(emb_dim * n_freqs, max_len=5000)

        self.skip_conn = skip_conn
        self.n_freqs = n_freqs
        # Per-frequency query/key width, chosen so that n_freqs * E is
        # approximately approx_qk_dim.
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        self.n_head = n_head
        self.V_dim = emb_dim // n_head
        self.emb_dim = emb_dim
        assert emb_dim % n_head == 0
        E = self.E

        self.use_flash_attention = use_flash_attention

        self.local_context_len = local_context_len

        self.add_module(
            "attn_conv_Q",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                # Split heads: [B, T, Q, n_head * E] -> [B * n_head, T, Q * E].
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_K",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_V",
            nn.Sequential(
                nn.Linear(emb_dim, (emb_dim // n_head) * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, (emb_dim // n_head))
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * (emb_dim // n_head))
                ),
                LayerNormalization4DCF((n_freqs, emb_dim // n_head), eps=eps),
            ),
        )

        self.dim_feedforward = dim_feedforward

        if dim_feedforward == -1:
            self.add_module(
                "attn_concat_proj",
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    get_layer(activation)(),
                    Lambda(lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])),
                    LayerNormalization4DCF((n_freqs, emb_dim), eps=eps),
                ),
            )
        else:
            self.linear1 = nn.Linear(emb_dim, dim_feedforward)
            self.dropout = nn.Dropout(p=0.1)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(dim_feedforward, emb_dim)
            self.dropout2 = nn.Dropout(p=0.1)
            self.norm = LayerNormalization4DCF((n_freqs, emb_dim), eps=eps)

    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def get_lookahead_mask(self, seq_len, device):
        """Return a [seq_len, seq_len] boolean mask that is True where
        attention is allowed: each frame may attend to itself and the past,
        optionally limited to the most recent ``local_context_len`` frames."""
        if self.local_context_len == -1:
            # Full causal (lower-triangular) mask.
            mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)
            return mask.detach().to(device)
        else:
            # Band-limited causal mask.
            mask1 = torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1
            mask2 = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=self.local_context_len) == 0
            mask = (mask1 * mask2).transpose(0, 1)
            return mask.detach().to(device)

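    # Illustrative mask (not part of the model): with seq_len=4 and
    # local_context_len=2 the method returns
    #   [[1, 0, 0, 0],
    #    [1, 1, 0, 0],
    #    [0, 1, 1, 0],
    #    [0, 0, 1, 1]]
    # i.e. every frame attends to itself and the previous frame only.
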
    def forward(self, batch):
        """
        Args:
            batch: [B, T, Q, C]
        """
        inputs = batch
        B0, T0, Q0, C0 = batch.shape

        # Add positional encoding over the flattened (Q, C) feature axis.
        pos_code = self.position_code(batch)
        _, T, QC = pos_code.shape
        pos_code = pos_code.reshape(1, T, Q0, C0)
        batch = batch + pos_code

        Q = self["attn_conv_Q"](batch)  # [B * n_head, T, Q0 * E]
        K = self["attn_conv_K"](batch)  # [B * n_head, T, Q0 * E]
        V = self["attn_conv_V"](batch)  # [B * n_head, T, Q0 * V_dim]

        emb_dim = Q.shape[-1]

        local_mask = self.get_lookahead_mask(batch.shape[1], batch.device)

        # Scaled dot-product attention with the causal/local mask.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)
        attn_mat.masked_fill_(local_mask == 0, -float("Inf"))
        attn_mat = F.softmax(attn_mat, dim=2)

        V = torch.matmul(attn_mat, V)
        V = V.reshape(-1, T0, V.shape[-1])
        V = V.transpose(1, 2)

        # Merge the heads back into the channel dimension: [B, T, Q, C].
        batch = V.reshape(B0, self.n_head, self.n_freqs, self.V_dim, T0)
        batch = batch.transpose(2, 3)
        batch = batch.reshape(B0, self.n_head * self.V_dim, self.n_freqs, T0)
        batch = batch.permute(0, 3, 2, 1)

        if self.dim_feedforward == -1:
            batch = self["attn_concat_proj"](batch)
        else:
            batch = batch + self._ff_block(batch)
            batch = batch.reshape(batch.shape[0], batch.shape[1], batch.shape[2] * batch.shape[3])
            batch = self.norm(batch)
            batch = batch.reshape(batch.shape[0], batch.shape[1], Q0, C0)

        if self.skip_conn:
            return batch + inputs
        else:
            return batch

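# Minimal usage sketch (shapes inferred from forward() above; the parameter
# values are assumptions, not defaults from the original project):
#   attn = Attention_STFT_causal(emb_dim=32, n_freqs=65, n_head=4)
#   x = torch.randn(2, 100, 65, 32)  # [B, T, Q, C]
#   y = attn(x)                      # same shape, causal over the T axis

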
class GridNetBlock(nn.Module):
    def __getitem__(self, key):
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        lstm_fold_chunk,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
        pool="mean",
        last=False,
        local_context_len=-1,
    ):
        super().__init__()
        bidirectional = True

        self.global_atten_causal = True

        self.last = last

        self.pool = pool

        self.lstm_fold_chunk = lstm_fold_chunk
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)

        self.V_dim = emb_dim // n_head
        self.H = hidden_channels
        in_channels = emb_dim * emb_ks
        self.in_channels = in_channels
        self.n_freqs = n_freqs

        # Intra-frame (frequency-axis) bidirectional RNN.
        self.intra_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
        self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs

        # Inter-frame (time-axis) RNN, run on folded chunks of length
        # lstm_fold_chunk.
        self.inter_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=bidirectional)
        self.inter_linear = nn.ConvTranspose1d(hidden_channels * (bidirectional + 1), emb_dim, emb_ks, stride=emb_hs)

        # Causal attention across chunk summaries.
        self.pool_atten_causal = Attention_STFT_causal(
            emb_dim=emb_dim,
            n_freqs=n_freqs,
            approx_qk_dim=approx_qk_dim,
            n_head=n_head,
            activation=activation,
            eps=eps,
            local_context_len=local_context_len,
        )

    def _unfold_timedomain(self, x):
        """Fold [BQ, C, T] into chunks, returning
        [BQ, T // lstm_fold_chunk, lstm_fold_chunk, C].
        T must be a multiple of lstm_fold_chunk."""
        BQ, C, T = x.shape
        x = torch.split(x, self.lstm_fold_chunk, dim=-1)
        x = torch.cat(x, dim=0).reshape(-1, BQ, C, self.lstm_fold_chunk)
        x = x.permute(1, 0, 3, 2)
        return x

    def forward(self, x, init_state=None):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]

        Returns:
            out: [B, C, T, Q]
        """
        B, C, old_T, old_Q = x.shape
        # Pad T and Q so the unfold with kernel emb_ks and stride emb_hs fits.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # Intra-frame RNN over the frequency axis.
        input_ = x
        intra_rnn = self.intra_norm(input_)
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)

        intra_rnn = torch.split(intra_rnn, self.emb_ks, dim=-1)
        intra_rnn = torch.stack(intra_rnn, dim=0)
        intra_rnn = intra_rnn.permute(1, 2, 3, 0).flatten(1, 2)
        intra_rnn = intra_rnn.transpose(1, 2)
        self.intra_rnn.flatten_parameters()

        intra_rnn, _ = self.intra_rnn(intra_rnn)
        intra_rnn = intra_rnn.transpose(1, 2)
        intra_rnn = self.intra_linear(intra_rnn)
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()
        intra_rnn = intra_rnn + input_
        intra_rnn = intra_rnn[:, :, :, :old_Q]
        Q = old_Q

        # Inter-frame RNN over the time axis, run on folded chunks.
        inter_rnn = self.inter_norm(intra_rnn)
        inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T)

        inter_rnn = self._unfold_timedomain(inter_rnn)

        BQ, NUM_CHUNK, CHUNKSIZE, C = inter_rnn.shape

        inter_rnn = inter_rnn.reshape(BQ * NUM_CHUNK, CHUNKSIZE, C)
        inter_rnn = inter_rnn.transpose(2, 1)
        input_ = inter_rnn

        inter_rnn = torch.split(inter_rnn, self.emb_ks, dim=-1)
        inter_rnn = torch.stack(inter_rnn, dim=0)
        inter_rnn = inter_rnn.permute(1, 2, 3, 0)

        BF, C, EO, _T = inter_rnn.shape
        inter_rnn = inter_rnn.reshape(BF, C * EO, _T)
        inter_rnn = inter_rnn.transpose(1, 2)

        self.inter_rnn.flatten_parameters()
        inter_rnn, _ = self.inter_rnn(inter_rnn)
        inter_rnn = inter_rnn.transpose(1, 2)
        inter_rnn = self.inter_linear(inter_rnn)
        inter_rnn = inter_rnn + input_

        inter_rnn = inter_rnn.reshape(B, Q, NUM_CHUNK, C, CHUNKSIZE)
        inter_rnn = inter_rnn.permute(0, 1, 2, 4, 3)

        # Pool each chunk to a single summary frame.
        input_ = inter_rnn
        if self.pool == "mean":
            inter_rnn = torch.mean(inter_rnn, dim=3)
        elif self.pool == "max":
            inter_rnn, _ = torch.max(inter_rnn, dim=3)
        else:
            raise ValueError("Invalid pool type!")

        # Causal attention across the chunk summaries.
        inter_rnn = inter_rnn.transpose(1, 2)
        inter_rnn = self.pool_atten_causal(inter_rnn)
        inter_rnn = inter_rnn.transpose(1, 2)

        if self.last:
            return inter_rnn, init_state
        else:
            # Broadcast the attended summaries back over the chunk axis.
            inter_rnn = inter_rnn.unsqueeze(3)
            inter_rnn = input_ + inter_rnn

            inter_rnn = inter_rnn.reshape(B, Q, T, C)
            inter_rnn = inter_rnn.permute(0, 3, 2, 1)
            inter_rnn = inter_rnn[..., :old_T, :]

            return inter_rnn, init_state
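

# Minimal usage sketch (illustrative; the parameter values are assumptions,
# not project defaults). Note that T must be a multiple of lstm_fold_chunk
# for the chunk folding in _unfold_timedomain to work:
#   block = GridNetBlock(emb_dim=32, emb_ks=1, emb_hs=1, n_freqs=65,
#                        hidden_channels=64, lstm_fold_chunk=50)
#   x = torch.randn(2, 32, 100, 65)  # [B, C, T, Q]
#   y, _ = block(x)                  # [B, C, T, Q]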