sereich's picture
Initial commit of Radio Upscaling UI (minus models)
f113387
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
import math
from src.models.snake import Snake
from src.models.utils import unfold
import typing as tp
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class BLSTM(nn.Module):
    """Bidirectional LSTM whose hidden size equals the input channel count.

    The concatenated forward/backward outputs (``2 * dim`` features) are
    projected back to ``dim`` channels by a linear layer.

    If ``max_steps`` is set and the input is longer than it, the sequence is
    cut into overlapping windows of ``max_steps`` samples (hop
    ``max_steps // 2``) using ``unfold``. Each window runs through the LSTM
    independently. The windows are then stitched back together after trimming
    ``max_steps // 4`` samples from each overlapping edge; the first window
    keeps its start and the last keeps its end.

    Args:
        dim: number of input/output channels (also the LSTM hidden size).
        layers: number of stacked LSTM layers.
        max_steps: optional window length; must be a multiple of 4.
        skip: if true, the input is added to the output.
    """

    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        # A multiple of 4 lets the window split evenly into a hop (width / 2)
        # and a trim margin (width / 4) used when stitching windows together.
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        """Run the BiLSTM over ``x`` of shape (batch, channels, time); output has the same shape."""
        batch, channels, length = x.shape
        residual = x
        width = self.max_steps
        windowed = width is not None and length > width
        if windowed:
            hop = width // 2
            # NOTE(review): assumes unfold returns (batch, channels, n_windows, width).
            windows = unfold(x, width, hop)
            n_windows = windows.shape[2]
            x = windows.permute(0, 2, 1, 3).reshape(-1, channels, width)

        # nn.LSTM (batch_first=False) expects (time, batch, features).
        seq = x.permute(2, 0, 1)
        seq, _ = self.lstm(seq)
        seq = self.linear(seq)
        x = seq.permute(1, 2, 0)

        if windowed:
            margin = hop // 2
            windows = x.reshape(batch, -1, channels, width)
            last = n_windows - 1
            pieces = []
            for idx in range(n_windows):
                if idx == 0:
                    piece = windows[:, idx, :, :-margin]
                elif idx == last:
                    piece = windows[:, idx, :, margin:]
                else:
                    piece = windows[:, idx, :, margin:-margin]
                pieces.append(piece)
            x = torch.cat(pieces, -1)[..., :length]

        return x + residual if self.skip else x
class LocalState(nn.Module):
    """Content-based self-attention over time without positional embeddings.

    Locality comes from learned, query-dependent decay penalties that grow
    with the absolute distance between key and query positions. Setting
    ``nfreqs > 0`` adds an extra term of cosine kernels over the key/query
    distance. That term is documented as a failed experiment.
    Each position is prevented from attending to itself.

    Args:
        channels: input/output channels; must be divisible by ``heads``.
        heads: number of attention heads.
        nfreqs: number of cosine periods for the frequency term (0 disables it).
        ndecay: number of decay rates for the locality penalty (0 disables it).
    """

    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Near-zero weights and a bias of -2 keep the sigmoid-gated decay
            # small at initialization, i.e. the widest attention window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # narrows Optional for type checkers
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, channels, time); returns ``x + proj(attention)``."""
        batch, _, length = x.shape
        n_heads = self.heads
        pos = torch.arange(length, device=x.device, dtype=x.dtype)
        # dist[t, s] = t - s, where t indexes keys and s indexes queries.
        dist = pos[:, None] - pos[None, :]

        def per_head(tensor):
            # Split channels into (heads, channels_per_head).
            return tensor.view(batch, n_heads, -1, length)

        q = per_head(self.query(x))
        k = per_head(self.key(x))
        scores = torch.einsum("bhct,bhcs->bhts", k, q) / k.shape[2] ** 0.5

        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            cos_kernel = torch.cos(2 * math.pi * dist / periods.view(-1, 1, 1))
            freq_gain = per_head(self.query_freqs(x)) / self.nfreqs ** 0.5
            scores = scores + torch.einsum("fts,bhfs->bhts", cos_kernel, freq_gain)

        if self.ndecay:
            rates = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_gain = torch.sigmoid(per_head(self.query_decay(x))) / 2
            decay_kernel = -rates.view(-1, 1, 1) * dist.abs() / self.ndecay ** 0.5
            scores = scores + torch.einsum("fts,bhfs->bhts", decay_kernel, decay_gain)

        # Strongly penalize each position attending to itself.
        self_mask = torch.eye(length, device=scores.device, dtype=torch.bool)
        scores = scores.masked_fill(self_mask, -100)
        attn = torch.softmax(scores, dim=2)
        values = per_head(self.content(x))
        mixed = torch.einsum("bhts,bhct->bhcs", attn, values).reshape(batch, -1, length)
        return x + self.proj(mixed)
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch.

    From Touvron et al. 2021 (https://arxiv.org/pdf/2103.17239.pdf). Every
    channel gets its own learnable factor, initialised to ``init``. A small
    ``init`` makes the branch start close to zero; its contribution is then
    learnt.

    The factors are broadcast over the second-to-last dimension of the input,
    which is the channel axis of a (batch, channels, time) tensor.

    Args:
        channels: number of channels to scale.
        init: initial value of every scale factor.
    """

    def __init__(self, channels: int, init: float = 0):
        super().__init__()
        values = torch.zeros(channels)
        values[:] = init
        self.scale = nn.Parameter(values)

    def forward(self, x):
        column = self.scale.unsqueeze(-1)
        return column * x
class DConv(nn.Module):
    """Stack of residual branches built from (dilated) convolutions.

    Each layer applies the following steps and adds the result to its input:

    1. Compress to ``int(channels / compress)`` channels with a
       ``kernel``-sized convolution, followed by a norm.
    2. Apply an activation.
    3. Optionally apply a BLSTM and/or LocalState attention.
    4. Expand back to ``channels`` with a 1x1 convolution, a norm, a GLU and
       a LayerScale.
    """

    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, time_attn=False, heads=4, ndecay=4, lstm=False,
                 act_func='gelu', freq_dim=None, reshape=False,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels of the residual branches.
            compress: channel compression factor inside each branch.
            depth: number of layers is ``abs(depth)``; a positive value enables
                dilation ``2 ** layer_index``, otherwise dilation is 1.
            init: initial value of each layer's LayerScale.
            norm: if true, use ``GroupNorm(1, ·)`` after each convolution,
                otherwise ``Identity``.
            time_attn: if true, add a LocalState attention to each branch.
            heads: number of heads of the LocalState attention.
            ndecay: number of decay controls of the LocalState attention.
            lstm: if true, add a BLSTM (2 layers, max_steps=200, skip=True) to each branch.
            act_func: 'gelu' -> GELU, 'snake' -> Snake(freq_dim), anything else -> ReLU.
            freq_dim: argument passed to Snake when ``act_func == 'snake'``.
            reshape: if true, ``forward`` expects a (batch, channels, freq, time)
                tensor and folds the freq axis into the batch axis.
            kernel: odd kernel size of the compressing convolution.
            dilate: NOTE(review): currently ignored; dilation is decided by the
                sign of ``depth`` only.
        """
        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # NOTE(review): this deliberately overrides the ``dilate`` argument
        # (behaviour inherited from the original code). It is kept so that
        # existing configs and checkpoints behave the same.
        use_dilation = depth > 0
        self.time_attn = time_attn
        self.lstm = lstm
        self.reshape = reshape
        self.act_func = act_func
        self.freq_dim = freq_dim
        self.hidden = int(channels / compress)

        def make_norm(dim: int) -> nn.Module:
            return nn.GroupNorm(1, dim) if norm else nn.Identity()

        def make_act() -> nn.Module:
            if act_func == 'gelu':
                return nn.GELU()
            if act_func == 'snake':
                return Snake(freq_dim)
            return nn.ReLU()

        self.layers = nn.ModuleList([])
        for index in range(self.depth):
            dilation = 2 ** index if use_dilation else 1
            padding = dilation * (kernel // 2)
            layer = nn.ModuleDict()
            # Insertion order and Sequential indices determine state_dict keys
            # and parameter-init order; keep them stable.
            layer['conv1'] = nn.Sequential(
                nn.Conv1d(channels, self.hidden, kernel, dilation=dilation, padding=padding),
                make_norm(self.hidden),
            )
            layer['act'] = make_act()
            layer['conv2'] = nn.Sequential(
                nn.Conv1d(self.hidden, 2 * channels, 1),
                make_norm(2 * channels),
                nn.GLU(1),
                LayerScale(channels, init),
            )
            if lstm:
                layer['lstm'] = BLSTM(self.hidden, layers=2, max_steps=200, skip=True)
            if time_attn:
                layer['time_attn'] = LocalState(self.hidden, heads=heads, ndecay=ndecay)
            self.layers.append(layer)

    def forward(self, x):
        """Apply every residual layer to ``x``; the output has the input's shape."""
        snake_fold = self.act_func == 'snake' and self.reshape
        if self.reshape:
            batch, chans, nfreq, length = x.shape
            # (B, C, Fr, T) -> (B * Fr, C, T)
            x = x.permute(0, 2, 1, 3).reshape(-1, chans, length)
        for layer in self.layers:
            residual = x
            h = layer['conv1'](x)
            if snake_fold:
                # Snake receives (B, hidden, T, Fr).
                h = h.view(batch, nfreq, self.hidden, length).permute(0, 2, 3, 1)
            h = layer['act'](h)
            if snake_fold:
                h = h.permute(0, 3, 1, 2).reshape(-1, self.hidden, length)
            if self.lstm:
                h = layer['lstm'](h)
            if self.time_attn:
                h = layer['time_attn'](h)
            x = residual + layer['conv2'](h)
        if self.reshape:
            x = x.view(batch, nfreq, chans, length).permute(0, 2, 1, 3)
        return x
class ScaledEmbedding(nn.Module):
    """Embedding whose effective learning rate is boosted by ``scale``.

    The stored table is divided by ``scale`` at construction and multiplied
    back by ``scale`` on every lookup.

    If ``smooth`` is set, the initial table is replaced by its cumulative sum
    over the index axis, with row ``i`` divided by ``sqrt(i + 1)``. This makes
    neighbouring indices start with similar (continuous) embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        table = self.embedding.weight.data
        if smooth:
            # Summing n Gaussian rows grows the std like sqrt(n); normalize by it.
            counts = torch.arange(1, num_embeddings + 1).to(table)
            table[:] = table.cumsum(dim=0) / counts.sqrt()[:, None]
        table /= scale
        self.scale = scale

    @property
    def weight(self):
        """Effective (rescaled) embedding table."""
        return self.embedding.weight * self.scale

    def forward(self, x):
        """Look up indices ``x`` and return the rescaled embeddings."""
        looked_up = self.embedding(x)
        return looked_up * self.scale
class FTB(nn.Module):
    """Time-frequency attention block (presumably PHASEN's Frequency Transformation Block — confirm).

    Input and output have shape (batch, in_channel, input_dim, time). The
    block works in four steps:

    1. A 1x1 conv, BatchNorm and ReLU reduce the input to ``r_channel``
       channels.
    2. All (channel, frequency) features are flattened and passed through a
       Conv1d over time. This yields a per-channel, per-frame gate that is
       broadcast over frequency.
    3. The gated input is mixed along the frequency axis by a bias-free
       linear layer.
    4. The mixed result is concatenated with the original input and fused
       back to ``in_channel`` channels by a 1x1 conv, BatchNorm and ReLU.

    Args:
        input_dim: size of the frequency axis of the input.
        in_channel: number of input/output channels.
        r_channel: number of channels after the reducing 1x1 convolution.
    """

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        super().__init__()
        self.input_dim = input_dim
        self.in_channel = in_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel),
            nn.ReLU(),
        )
        self.conv1d = nn.Sequential(
            nn.Conv1d(r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel),
            nn.ReLU(),
        )
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Transform ``inputs`` of shape [Batch, Ca, Dim, Time]; the output has the same shape."""
        reduced = self.conv1(inputs)
        batch, chans, dim, length = reduced.shape
        # Gate over time computed from every (channel, freq) feature.
        gate = self.conv1d(reduced.reshape(batch, chans * dim, length))
        gate = gate.reshape(batch, self.in_channel, 1, length)
        attended = gate * inputs
        # Linear mixing along the frequency axis (moved last for nn.Linear).
        mixed = self.freq_fc(attended.transpose(2, 3)).transpose(2, 3)
        return self.conv2(torch.cat([mixed, inputs], 1))