import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.transformer import _get_clones
from einops.layers.torch import Rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

# helper classes

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# attention, feedforward, and conv module

class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class Attention(nn.Module):
    # Head Token attention: https://arxiv.org/pdf/2210.05958.pdf
    # Note: the qkv reshape below assumes heads * dim_head == dim.
    def __init__(self, dim, heads=8, dim_head=64, qkv_bias=False, dropout=0., proj_drop=0.):
        super().__init__()
        self.num_heads = heads
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5

        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.act = nn.GELU()
        self.ht_proj = nn.Linear(dim_head, dim, bias=True)
        self.ht_norm = nn.LayerNorm(dim_head)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_heads, dim))

    def forward(self, x, mask=None):
        # mask is accepted for interface compatibility but not currently applied
        B, N, C = x.shape

        # head tokens: one averaged, projected token per attention head
        head_pos = self.pos_embed.expand(x.shape[0], -1, -1)
        x_ = x.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x_ = x_.mean(dim=2)  # [B, h, d // h]
        x_ = self.ht_proj(x_).reshape(B, -1, self.num_heads, C // self.num_heads)
        x_ = self.act(self.ht_norm(x_)).flatten(2)
        x_ = x_ + head_pos
        x = torch.cat([x, x_], dim=1)

        # standard multi-head self-attention over the extended sequence
        qkv = self.qkv(x).reshape(B, N + self.num_heads, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N + self.num_heads, C)
        x = self.proj(x)

        # merge head tokens into cls token
        cls, patch, ht = torch.split(x, [1, N - 1, self.num_heads], dim=1)
        cls = cls + torch.mean(ht, dim=1, keepdim=True) + torch.mean(patch, dim=1, keepdim=True)
        x = torch.cat([cls, patch], dim=1)

        x = self.proj_drop(x)

        return x, attn

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# Conformer Block

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.,
        ff_dropout=0.,
        conv_dropout=0.,
        conv_causal=False
    ):
        super().__init__()
        self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        self.attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout)
        self.conv = ConformerConvModule(dim=dim, causal=conv_causal, expansion_factor=conv_expansion_factor,
                                        kernel_size=conv_kernel_size, dropout=conv_dropout)
        self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        x = self.ff1(x) + x
        attn_x, attn_weight = self.attn(x, mask=mask)
        x = attn_x + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x, attn_weight

# Conformer

class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.,
        ff_dropout=0.,
        conv_dropout=0.,
        conv_causal=False
    ):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(ConformerBlock(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                ff_mult=ff_mult,
                conv_expansion_factor=conv_expansion_factor,
                conv_kernel_size=conv_kernel_size,
                conv_causal=conv_causal
            ))

    def forward(self, x):
        for block in self.layers:
            x, _ = block(x)  # each block returns (output, attention weights)
        return x

def sinusoidal_embedding(n_channels, dim):
    # fixed sinusoidal positional encoding, shape [1, n_channels, dim]
    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                            for p in range(n_channels)])
    pe[:, 0::2] = torch.sin(pe[:, 0::2])
    pe[:, 1::2] = torch.cos(pe[:, 1::2])
    return pe.unsqueeze(0)

class FinalConformer(nn.Module):
    def __init__(self, emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=1):
        super().__init__()
        self.dim_head = emb_size // heads
        self.dim = emb_size
        self.heads = heads
        self.kernel_size = kernel_size
        self.n_encoders = n_encoders

        self.positional_emb = nn.Parameter(sinusoidal_embedding(10000, emb_size), requires_grad=False)
        self.encoder_blocks = _get_clones(
            ConformerBlock(
                dim=emb_size,
                dim_head=self.dim_head,
                heads=heads,
                ff_mult=ffmult,
                conv_expansion_factor=exp_fac,
                conv_kernel_size=kernel_size
            ),
            n_encoders
        )
        self.class_token = nn.Parameter(torch.rand(1, emb_size))
        self.fc5 = nn.Linear(emb_size, 2)

    def forward(self, x):
        # x shape [bs, time, freq]
        x = x + self.positional_emb[:, :x.size(1), :]

        # prepend the class token to every sequence -> [bs, 1 + time, emb_size]
        x = torch.stack([torch.vstack((self.class_token, x[i])) for i in range(len(x))])

        list_attn_weight = []
        for layer in self.encoder_blocks:
            x, attn_weight = layer(x)  # [bs, 1 + time, emb_size]
            list_attn_weight.append(attn_weight)

        embedding = x[:, 0, :]      # [bs, emb_size]
        out = self.fc5(embedding)   # [bs, 2]
        return out, list_attn_weight
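# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the batch size, sequence length and
# hyperparameters below are assumptions for a quick smoke test, not values
# from the original training setup. The head-token attention reshape requires
# heads * dim_head == dim, which FinalConformer satisfies by construction
# (dim_head = emb_size // heads).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = FinalConformer(emb_size=128, heads=4, ffmult=4, exp_fac=2,
                           kernel_size=16, n_encoders=2)

    # Hypothetical input: a batch of 4 sequences of 100 frames with
    # emb_size (= 128) features each, i.e. [bs, time, emb_size].
    dummy = torch.randn(4, 100, 128)

    logits, attn_weights = model(dummy)
    print(logits.shape)           # torch.Size([4, 2])
    print(len(attn_weights))      # one attention map per encoder block
    print(attn_weights[0].shape)  # [bs, heads, 1 + time + heads, 1 + time + heads]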