|
|
import math
import typing as tp

import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm

from src.models.snake import Snake
from src.models.utils import unfold


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))
|
|
|
|
|
|
|
|
def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
|
|
|
|
class BLSTM(nn.Module):
    """
    BiLSTM with the same number of hidden units as the input dimension.
    If `max_steps` is not None, the input is split into overlapping
    chunks and the LSTM is applied separately to each chunk.
    """

    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False
        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2
            for k in range(nframes):
                # Trim the overlapping halves so consecutive chunks can be
                # concatenated back into a single sequence.
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out
        if self.skip:
            x = x + y
        return x
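
# A minimal usage sketch (not part of the model): BLSTM takes [B, C, T] and returns
# the same shape. With `max_steps` set, longer inputs are processed in overlapping
# chunks and stitched back together before the optional skip connection.
#
#     blstm = BLSTM(dim=64, layers=2, max_steps=200, skip=True)
#     x = torch.randn(2, 64, 1000)   # [B, C, T], T > max_steps triggers chunking
#     y = blstm(x)                   # -> torch.Size([2, 64, 1000])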
|
|
|
|
|
|
|
|
class LocalState(nn.Module):
    """Local state allows attention based only on data (no positional embedding),
    while still constraining the time window (e.g. through a decaying penalty term).

    Also includes a failed experiment at providing some frequency-based attention.
    """

    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize the decay queries close to zero (a sigmoid is applied later),
            # so the attention window starts out as wide as possible.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None
            self.query_decay.bias.data[:] = -2

        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # First index (t) refers to keys, second index (s) to queries.
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)

        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2] ** 0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            # Penalty growing with the distance |t - s| between key and query.
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference: a position cannot attend to itself.
        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        result = result.reshape(B, -1, T)
        return x + self.proj(result)
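
# A minimal usage sketch (not part of the model): LocalState is a residual,
# content-only self-attention over time. The learnt decay queries add a penalty
# growing with |t - s|, so attention is effectively restricted to a local window.
#
#     attn = LocalState(channels=64, heads=4, ndecay=4)
#     x = torch.randn(2, 64, 300)    # [B, C, T]
#     y = attn(x)                    # residual output, same shape [2, 64, 300]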
|
|
|
|
|
|
|
|
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al. 2021] (https://arxiv.org/pdf/2103.17239.pdf).

    This rescales the residual branch output diagonally (per channel). The scale is
    initialized close to 0 so the branch is almost disabled at first, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x
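
# A minimal usage sketch (not part of the model): LayerScale is meant to wrap the
# output of a residual branch, e.g. `x = x + scale(branch(x))`, so the branch starts
# nearly disabled (scale == init) and its contribution is learnt during training.
#
#     scale = LayerScale(channels=64, init=1e-4)
#     h = torch.randn(2, 64, 100)    # [B, C, T]
#     out = scale(h)                 # per-channel rescaling, same shape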
|
|
|
|
|
|
|
|
class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Before entering each residual branch, the dimension is also projected onto a
    smaller subspace, e.g. of dim `channels // compress`.
    """

    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, time_attn=False, heads=4, ndecay=4, lstm=False,
                 act_func='gelu', freq_dim=None, reshape=False,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for the residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention. A negative value
                disables dilation; only `abs(depth)` layers are used.
            init: initial scale for LayerScale.
            norm: use GroupNorm.
            time_attn: use LocalState attention along the time axis.
            heads: number of heads for the LocalState attention.
            ndecay: number of decay controls in the LocalState attention.
            lstm: use LSTM.
            act_func: activation function, 'gelu' or 'snake' (anything else falls back to ReLU).
            freq_dim: frequency dimension, only used by the Snake activation.
            reshape: if True, the input is [B, C, Fr, T] and the frequency axis is
                folded into the batch before the 1D convolutions.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """
        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # A negative `depth` disables dilation, overriding the `dilate` argument.
        dilate = depth > 0

        self.time_attn = time_attn
        self.lstm = lstm
        self.reshape = reshape
        self.act_func = act_func
        self.freq_dim = freq_dim

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)

        self.hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if act_func == 'gelu':
            act = nn.GELU
        elif act_func == 'snake':
            act = Snake
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            layer = nn.ModuleDict()
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            conv1 = nn.ModuleList([nn.Conv1d(channels, self.hidden, kernel, dilation=dilation, padding=padding),
                                   norm_fn(self.hidden)])
            act_layer = act(freq_dim) if act_func == 'snake' else act()
            conv2 = nn.ModuleList([nn.Conv1d(self.hidden, 2 * channels, 1),
                                   norm_fn(2 * channels), nn.GLU(1),
                                   LayerScale(channels, init)])

            layer.update({'conv1': nn.Sequential(*conv1), 'act': act_layer, 'conv2': nn.Sequential(*conv2)})
            if lstm:
                layer.update({'lstm': BLSTM(self.hidden, layers=2, max_steps=200, skip=True)})
            if time_attn:
                layer.update({'time_attn': LocalState(self.hidden, heads=heads, ndecay=ndecay)})

            self.layers.append(layer)

    def forward(self, x):
        if self.reshape:
            # Fold the frequency axis into the batch: [B, C, Fr, T] -> [B * Fr, C, T].
            B, C, Fr, T = x.shape
            x = x.permute(0, 2, 1, 3).reshape(-1, C, T)

        for layer in self.layers:
            skip = x

            x = layer['conv1'](x)

            if self.act_func == 'snake' and self.reshape:
                # Snake needs the frequency axis back: [B * Fr, H, T] -> [B, H, T, Fr].
                x = x.view(B, Fr, self.hidden, T).permute(0, 2, 3, 1)
            x = layer['act'](x)
            if self.act_func == 'snake' and self.reshape:
                x = x.permute(0, 3, 1, 2).reshape(-1, self.hidden, T)

            if self.lstm:
                x = layer['lstm'](x)
            if self.time_attn:
                x = layer['time_attn'](x)

            x = layer['conv2'](x)
            x = skip + x

        if self.reshape:
            x = x.view(B, Fr, C, T).permute(0, 2, 1, 3)

        return x
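
# A minimal usage sketch (not part of the model) for the two input layouts supported
# by `forward`. The snake/freq_dim combination assumes the repo's Snake activation
# accepts the frequency dimension, exactly as `act(freq_dim)` is called above.
#
#     # Time branch: plain [B, C, T] input.
#     dconv_t = DConv(channels=64, lstm=True, time_attn=True)
#     yt = dconv_t(torch.randn(2, 64, 400))            # -> [2, 64, 400]
#
#     # Frequency branch: [B, C, Fr, T] input, frequencies folded into the batch.
#     dconv_f = DConv(channels=64, reshape=True, act_func='snake', freq_dim=128)
#     yf = dconv_f(torch.randn(2, 64, 128, 100))       # -> [2, 64, 128, 100]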
|
|
|
|
|
|
|
|
class ScaledEmbedding(nn.Module):
    """
    Boost the learning rate for the embeddings (with `scale`).
    Also, can make the embeddings continuous with `smooth`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            # A cumulative sum makes neighbouring embeddings close to each other;
            # dividing by sqrt(n) keeps the variance roughly constant across indices.
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight
        # Store the weights divided by `scale`; `forward` multiplies by `scale`,
        # which boosts the effective learning rate of this layer.
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        return self.embedding.weight * self.scale

    def forward(self, x):
        out = self.embedding(x) * self.scale
        return out
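
# A minimal usage sketch (not part of the model): the table is stored divided by
# `scale` and multiplied back in `forward`, which boosts the effective learning
# rate of this layer, as the docstring notes.
#
#     emb = ScaledEmbedding(num_embeddings=512, embedding_dim=48, scale=10., smooth=True)
#     idx = torch.arange(512)
#     z = emb(idx)                   # -> [512, 48]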
|
|
|
|
|
|
|
|
class FTB(nn.Module):

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        super().__init__()
        self.input_dim = input_dim
        self.in_channel = in_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel),
            nn.ReLU()
        )

        self.conv1d = nn.Sequential(
            nn.Conv1d(r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel),
            nn.ReLU()
        )
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel),
            nn.ReLU()
        )

    def forward(self, inputs):
        """
        inputs should be [Batch, Ca, Dim, Time]
        """
        # Attention weights over time: a 1x1 conv reduces the channels, the channel
        # and frequency axes are flattened, and a 1D conv over time produces
        # per-channel weights of shape [B, Ca, 1, T].
        conv1_out = self.conv1(inputs)
        B, C, D, T = conv1_out.size()
        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
        conv1d_out = self.conv1d(reshape1_out)
        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])

        # Apply the attention weights to the input.
        att_out = conv1d_out * inputs

        # Frequency transformation: a learned linear map across the frequency axis,
        # applied independently at each time step.
        att_out = torch.transpose(att_out, 2, 3)
        freqfc_out = self.freq_fc(att_out)
        att_out = torch.transpose(freqfc_out, 2, 3)

        cat_out = torch.cat([att_out, inputs], 1)
        outputs = self.conv2(cat_out)
        return outputs
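
# A minimal usage sketch (not part of the model): FTB preserves the input layout.
#
#     ftb = FTB(input_dim=257, in_channel=9, r_channel=5)
#     spec = torch.randn(4, 9, 257, 100)    # [Batch, Ca, Dim, Time]
#     out = ftb(spec)                       # same shape: [4, 9, 257, 100]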
|
|
|