| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from einops import rearrange
|
| from einops.layers.torch import Rearrange, Reduce
|
|
|
class PatchEmbedding(nn.Module):
    """Convert a 4-D input into a token sequence with a class token.

    A convolution whose kernel and stride are both ``(1, patch_size)`` cuts
    the last axis into non-overlapping patches and maps each patch to
    ``emb_size`` features. The result is flattened to ``(batch, tokens,
    emb_size)``, a learnable class token is placed in front, and the leading
    rows of a learnable positional table are added.

    Args:
        emb_size: Number of features per token.
        patch_size: Width and stride of the patching convolution.
        in_channels: Number of channels of the input tensor.
    """

    def __init__(self, emb_size=40, patch_size=25, in_channels=1):
        super().__init__()
        patch_shape = (1, patch_size)
        patcher = nn.Conv2d(
            in_channels, emb_size, kernel_size=patch_shape, stride=patch_shape
        )
        to_tokens = Rearrange('b e (h) (w) -> b (h w) e')
        self.projection = nn.Sequential(patcher, to_tokens)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # The table has 2000 rows; forward() only reads as many rows as
        # there are tokens (class token included).
        self.positions = nn.Parameter(torch.randn(2000, emb_size))

    def forward(self, x):
        """Return ``(batch, 1 + num_patches, emb_size)`` tokens for 4-D ``x``."""
        batch_size, _, _, _ = x.shape
        patch_tokens = self.projection(x)
        # One copy of the class token per sample (expand does not copy data).
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        num_tokens = tokens.shape[1]
        return tokens + self.positions[:num_tokens, :]
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product self-attention over ``(batch, tokens, emb_size)``.

    Queries, keys and values come from three separate linear maps of the
    input. The feature axis is split into ``num_heads`` equal heads,
    attention is computed per head, and the heads are concatenated and passed
    through an output projection.

    Args:
        emb_size: Feature size of the input and the output.
        num_heads: Number of attention heads; must divide ``emb_size``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``emb_size``.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        if num_heads <= 0 or emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t):
        """Reshape ``(b, n, h * d)`` to ``(b, h, n, d)`` with ``h`` outermost."""
        batch, tokens, _ = t.shape
        return t.reshape(batch, tokens, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, tokens, emb_size)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, tokens, tokens)``. Truthy entries mark
                query/key pairs that may attend; falsy entries are excluded.
                ``None`` means full attention.

        Returns:
            Tensor of shape ``(batch, tokens, emb_size)``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        # NOTE: scaled by sqrt(emb_size), not sqrt(head_dim); kept as is so
        # that existing weights keep producing the same outputs.
        scaling = self.emb_size ** (1 / 2)
        energy = energy / scaling
        if mask is not None:
            # The most negative finite value gives ~0 weight after softmax
            # without producing NaN when a whole row is masked.
            blocked = ~mask.to(torch.bool)
            energy = energy.masked_fill(blocked, torch.finfo(energy.dtype).min)

        att = F.softmax(energy, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)

        # Merge heads back: (b, h, n, d) -> (b, n, h * d).
        batch, _, tokens, _ = out.shape
        out = out.permute(0, 2, 1, 3).reshape(batch, tokens, -1)
        return self.projection(out)
|
|
|
class ResidualAdd(nn.Module):
    """Wrap a callable ``fn`` so that the block computes ``fn(x) + x``.

    Args:
        fn: Module (or callable) applied to the input. Its output must be
            broadcast-compatible with the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x`` without modifying ``x``.

        The sum is computed out of place. An in-place add would overwrite the
        caller's tensor whenever ``fn`` hands back its own input (for example
        ``nn.Identity``), and would fail on leaf tensors that require grad.
        """
        return self.fn(x, **kwargs) + x
|
|
|
class FeedForwardBlock(nn.Sequential):
    """Two-layer position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Feature size of the input and the output.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, dropout):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)
|
|
|
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder layer: attention and MLP, each with a residual.

    Both sub-layers have the form ``LayerNorm -> op -> Dropout`` and are
    wrapped in ``ResidualAdd`` so their output is added to their input.

    Args:
        emb_size: Feature size of the tokens.
        num_heads: Number of attention heads.
        drop_p: Dropout probability used inside the attention and after both
            sub-layers.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout probability inside the feed-forward block.
    """

    def __init__(
        self,
        emb_size,
        num_heads=4,
        drop_p=0.5,
        forward_expansion=4,
        forward_drop_p=0.5,
    ):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, dropout=forward_drop_p
            ),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )
|
|
|
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks applied one after another.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers.
        emb_size: Feature size of the tokens; every block uses its other
            settings at their defaults.
    """

    def __init__(self, depth, emb_size):
        blocks = [TransformerEncoderBlock(emb_size) for _ in range(depth)]
        super().__init__(*blocks)
|
|
|
class ClassificationHead(nn.Sequential):
    """Average the tokens, normalise, and map to class scores.

    Args:
        emb_size: Feature size of the incoming tokens.
        n_classes: Number of output scores per sample.
    """

    def __init__(self, emb_size, n_classes):
        super().__init__()
        token_mean = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        classifier = nn.Linear(emb_size, n_classes)
        self.clshead = nn.Sequential(token_mean, norm, classifier)

    def forward(self, x):
        """Map ``(batch, tokens, emb_size)`` to ``(batch, n_classes)``."""
        scores = self.clshead(x)
        return scores
|
|
|
class EEGConformer(nn.Module):
    """Convolutional front end followed by a transformer encoder and a head.

    The front end applies a ``(1, 25)`` convolution along the last axis, a
    ``(channels, 1)`` convolution that collapses the channel axis, batch
    norm, ELU, ``(1, 75)`` average pooling with stride 15, and dropout. Each
    remaining position along the last axis becomes one 40-feature token for a
    4-layer transformer encoder, whose output is classified.

    Args:
        n_classes: Number of output scores per sample.
        channels: Height of the second convolution kernel; presumably the
            number of recording channels in the input.
        time_points: Accepted for interface compatibility but not used; no
            layer depends on the input length. The kernel sizes imply the
            last axis needs at least 99 samples.
    """

    def __init__(self, n_classes=3, channels=64, time_points=1000):
        super().__init__()
        emb_size = 40

        self.conv1 = nn.Conv2d(1, emb_size, (1, 25), (1, 1))
        self.conv2 = nn.Conv2d(emb_size, emb_size, (channels, 1), (1, 1))
        self.batchnorm1 = nn.BatchNorm2d(emb_size)
        self.avgpool1 = nn.AvgPool2d((1, 75), (1, 15))
        # Not used in forward(); kept so the attribute stays available.
        self.flatten = nn.Flatten()

        # Tokens always carry the emb_size conv channels as features, so the
        # input size is known up front and no lazy initialisation is needed.
        # Parameters therefore exist right after construction.
        self.projection = nn.Linear(emb_size, emb_size)

        self.transformer = TransformerEncoder(depth=4, emb_size=emb_size)
        self.classification = ClassificationHead(
            emb_size=emb_size, n_classes=n_classes
        )

    def forward(self, x):
        """Return class scores of shape ``(batch, n_classes)``.

        Args:
            x: Either a 4-D tensor with one input channel at axis 1, or a
                3-D tensor, which gets that axis inserted.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.batchnorm1(x)
        x = F.elu(x)
        x = self.avgpool1(x)
        # Functional dropout defaults to training=True; tie it to the module
        # mode so that eval() really disables it.
        x = F.dropout(x, 0.5, training=self.training)
        x = rearrange(x, 'b e h w -> b (h w) e')
        x = self.projection(x)
        x = self.transformer(x)
        x = self.classification(x)
        return x
|
|
|