import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.transformer import _get_clones
from einops.layers.torch import Rearrange


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
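

# Quick sanity check of calc_same_padding (illustrative values, not part of the API):
#   calc_same_padding(31) -> (15, 15)  # odd kernel: symmetric padding
#   calc_same_padding(16) -> (8, 7)    # even kernel: one column less on the right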


class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        # pad only the temporal (last) dimension before the depth-wise convolution
        x = F.pad(x, self.padding)
        return self.conv(x)


class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class Attention(nn.Module):
    # Multi-head self-attention augmented with per-head "head tokens": one
    # aggregated token per head is appended to the sequence before attention
    # and folded back into the class token afterwards. The qkv reshape below
    # assumes dim == heads * dim_head and that a class token sits at index 0.
    def __init__(self, dim, heads=8, dim_head=64, qkv_bias=False, dropout=0., proj_drop=0.):
        super().__init__()
        self.num_heads = heads
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5

        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.act = nn.GELU()
        self.ht_proj = nn.Linear(dim_head, dim, bias=True)
        self.ht_norm = nn.LayerNorm(dim_head)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_heads, dim))

    def forward(self, x, mask=None):
        # NOTE: mask is accepted for interface compatibility but is not applied.
        B, N, C = x.shape

        # build one aggregated token per head from the head-wise mean of the input
        head_pos = self.pos_embed.expand(x.shape[0], -1, -1)
        x_ = x.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x_ = x_.mean(dim=2)  # (B, heads, dim_head)
        x_ = self.ht_proj(x_).reshape(B, -1, self.num_heads, C // self.num_heads)
        x_ = self.act(self.ht_norm(x_)).flatten(2)
        x_ = x_ + head_pos
        x = torch.cat([x, x_], dim=1)  # (B, N + heads, C)

        qkv = self.qkv(x).reshape(B, N + self.num_heads, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N + self.num_heads, C)
        x = self.proj(x)

        # fold the head tokens back into the class token, then drop them
        cls, patch, ht = torch.split(x, [1, N - 1, self.num_heads], dim=1)
        cls = cls + torch.mean(ht, dim=1, keepdim=True) + torch.mean(patch, dim=1, keepdim=True)
        x = torch.cat([cls, patch], dim=1)

        x = self.proj_drop(x)

        return x, attn
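

# Minimal shape walkthrough for Attention (illustrative assumptions: dim 128,
# 8 heads of size 16, and a 1 + 49 token sequence with the class token at index 0):
def _demo_attention():
    attn = Attention(dim=128, heads=8, dim_head=16)
    x = torch.randn(2, 1 + 49, 128)         # (batch, tokens, dim)
    out, weights = attn(x)
    assert out.shape == (2, 50, 128)        # head tokens are stripped from the output
    assert weights.shape == (2, 8, 58, 58)  # attention spans the 50 + 8 head tokens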


class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class ConformerConvModule(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        # causal mode pads entirely on the left; otherwise pad for "same" output length
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
        self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        # macaron structure: two half-weighted feedforwards sandwich attention and conv
        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        attn_x, attn_weight = self.attn(x, mask = mask)
        x = attn_x + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x, attn_weight
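

# Minimal shape check for ConformerBlock (illustrative sizes; the Attention inside
# requires dim == heads * dim_head and a class token at index 0):
def _demo_conformer_block():
    block = ConformerBlock(dim=128, dim_head=16, heads=8)
    x = torch.randn(2, 1 + 49, 128)
    out, attn_weight = block(x)
    assert out.shape == x.shape
    assert attn_weight.shape == (2, 8, 58, 58)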


class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(ConformerBlock(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                conv_dropout = conv_dropout,
                conv_causal = conv_causal
            ))

    def forward(self, x):
        for block in self.layers:
            # ConformerBlock returns (x, attn_weight); only x is propagated here
            x, _ = block(x)
        return x
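

# Usage sketch for the plain Conformer stack (depth/width values are arbitrary examples):
def _demo_conformer():
    model = Conformer(dim=128, depth=2, dim_head=16, heads=8)
    x = torch.randn(2, 1 + 49, 128)  # class token expected at index 0
    out = model(x)
    assert out.shape == x.shape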


def sinusoidal_embedding(n_channels, dim):
    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                            for p in range(n_channels)])
    pe[:, 0::2] = torch.sin(pe[:, 0::2])
    pe[:, 1::2] = torch.cos(pe[:, 1::2])
    return pe.unsqueeze(0)
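

# e.g. sinusoidal_embedding(10000, 128) is a fixed table of shape (1, 10000, 128);
# FinalConformer below slices its first x.size(1) positions as the positional encoding.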


class FinalConformer(nn.Module):
    def __init__(self, emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=1):
        super().__init__()
        self.dim_head = emb_size // heads
        self.dim = emb_size
        self.heads = heads
        self.kernel_size = kernel_size
        self.n_encoders = n_encoders
        # fixed (non-trainable) sinusoidal positional table
        self.positional_emb = nn.Parameter(sinusoidal_embedding(10000, emb_size), requires_grad=False)
        self.encoder_blocks = _get_clones(ConformerBlock(dim = emb_size, dim_head = self.dim_head, heads = heads,
                                                         ff_mult = ffmult, conv_expansion_factor = exp_fac,
                                                         conv_kernel_size = kernel_size),
                                          n_encoders)
        self.class_token = nn.Parameter(torch.rand(1, emb_size))
        self.fc5 = nn.Linear(emb_size, 2)

    def forward(self, x):
        x = x + self.positional_emb[:, :x.size(1), :]
        # prepend the class token to every sequence in the batch
        cls = self.class_token.unsqueeze(0).expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)
        list_attn_weight = []
        for layer in self.encoder_blocks:
            x, attn_weight = layer(x)
            list_attn_weight.append(attn_weight)
        embedding = x[:, 0, :]
        out = self.fc5(embedding)
        return out, list_attn_weight
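

# Minimal smoke test; the batch size and frame count below are arbitrary
# assumptions for illustration, not values fixed by the model.
if __name__ == "__main__":
    model = FinalConformer(emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=2)
    feats = torch.randn(4, 200, 128)  # (batch, frames, emb_size)
    logits, attn_maps = model(feats)
    print(logits.shape)               # torch.Size([4, 2])
    print(len(attn_maps), attn_maps[0].shape)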