"""EEG Conformer: a shallow convolutional front end followed by a Transformer encoder.

Defines the building blocks (patch embedding, multi-head self-attention,
residual wrapper, feed-forward block, encoder stack, classification head) and
the ``EEGConformer`` model that wires the convolutional front end, the encoder
stack and the classification head together.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce


class PatchEmbedding(nn.Module):
    """Split the input into non-overlapping patches along its last axis and embed them.

    A ``(1, patch_size)`` convolution with stride ``(1, patch_size)`` maps each
    patch to ``emb_size`` features, a learnable class token is prepended, and
    learnable positional embeddings are added.

    NOTE(review): this module is not used by ``EEGConformer`` in this file.

    Args:
        emb_size: Embedding dimension of each token.
        patch_size: Width of each patch along the input's last axis.
        in_channels: Number of input feature maps (dim 1 of the input).
    """

    def __init__(self, emb_size=40, patch_size=25, in_channels=1):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size,
                      kernel_size=(1, patch_size), stride=(1, patch_size)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # The positional table has 2000 rows, so at most 2000 tokens
        # (class token included) can be embedded.
        self.positions = nn.Parameter(torch.randn(2000, emb_size))

    def forward(self, x):
        """Embed a 4-D tensor ``(b, in_channels, h, w)``.

        Returns:
            Tensor of shape ``(b, 1 + h * (w // patch_size), emb_size)``.
        """
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.positions[:x.shape[1], :]
        return x


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        emb_size: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(b, n, emb_size)``; returns the same shape.

        Args:
            x: Input tokens.
            mask: Optional boolean tensor broadcastable to ``(b, num_heads, n, n)``.
                Query/key pairs where it is ``False`` are excluded from
                attention. ``None`` (the default) attends everywhere.
        """
        queries = rearrange(self.queries(x), 'b n (h d) -> b h n d', h=self.num_heads)
        keys = rearrange(self.keys(x), 'b n (h d) -> b h n d', h=self.num_heads)
        values = rearrange(self.values(x), 'b n (h d) -> b h n d', h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf so a
            # fully masked row yields a uniform distribution instead of NaNs.
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)
        # NOTE: scales by sqrt(emb_size), not sqrt(emb_size // num_heads) as in
        # "Attention Is All You Need". Kept as-is so trained weights keep the
        # same attention temperature.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a skip connection, computing ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add: an in-place ``+=`` on fn's output would also
        # modify ``x`` whenever ``fn`` returns its input tensor unchanged
        # (e.g. an identity or a no-op dropout), doubling the residual.
        return self.fn(x, **kwargs) + x


class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input/output feature size.
        expansion: Hidden-layer width as a multiple of ``emb_size``.
        dropout: Dropout probability after the GELU.
    """

    def __init__(self, emb_size, expansion, dropout):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expansion * emb_size, emb_size),
        )


class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm Transformer block: residual attention, then residual feed-forward.

    Args:
        emb_size: Token embedding size (must be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        drop_p: Dropout for attention weights and after each sub-layer.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout inside the feed-forward block.
    """

    def __init__(self, emb_size, num_heads=4, drop_p=0.5,
                 forward_expansion=4, forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p),
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(emb_size, expansion=forward_expansion,
                                 dropout=forward_drop_p),
                nn.Dropout(drop_p),
            )),
        )


class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` ``TransformerEncoderBlock``s with default settings."""

    def __init__(self, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])


class ClassificationHead(nn.Sequential):
    """Mean-pool tokens, layer-normalize, and project to class logits.

    Subclasses ``nn.Sequential`` but routes ``forward`` through the single
    ``clshead`` submodule.

    Args:
        emb_size: Token embedding size.
        n_classes: Number of output logits.
    """

    def __init__(self, emb_size, n_classes):
        super().__init__()
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes),
        )

    def forward(self, x):
        """Map tokens ``(b, n, emb_size)`` to logits ``(b, n_classes)``."""
        return self.clshead(x)


class EEGConformer(nn.Module):
    """Convolutional front end + Transformer encoder + classification head.

    The front end applies a temporal convolution, a spatial convolution
    spanning all ``channels`` rows, batch norm, ELU, average pooling and
    dropout. Each pooled time step becomes a 40-feature token, which is
    projected to ``emb_size`` and fed to the encoder.

    Args:
        n_classes: Number of output logits.
        channels: Size of the input's second-to-last axis (presumably EEG
            electrodes); the spatial convolution spans all of them.
        time_points: Samples per trial. Not used to build any layer; kept
            for API compatibility. Forward requires at least 99 time points
            (the temporal conv and pooling kernels).
        emb_size: Transformer token size. Must be divisible by 4, the number
            of heads used by ``TransformerEncoderBlock``.
        depth: Number of Transformer encoder blocks.
    """

    def __init__(self, n_classes=3, channels=64, time_points=1000,
                 emb_size=40, depth=4):
        super().__init__()
        n_filters = 40  # feature maps produced by the convolutional front end
        self.conv1 = nn.Conv2d(1, n_filters, (1, 25), (1, 1))
        # Kernel height == channels collapses the channel axis to size 1.
        self.conv2 = nn.Conv2d(n_filters, n_filters, (channels, 1), (1, 1))
        self.batchnorm1 = nn.BatchNorm2d(n_filters)
        self.avgpool1 = nn.AvgPool2d((1, 75), (1, 15))
        # Unused in forward; retained so existing attribute access is unaffected.
        self.flatten = nn.Flatten()
        # Tokens always carry n_filters features, so the input size is known
        # statically. Unlike LazyLinear, no dry-run forward is needed before
        # building an optimizer; state_dict keys and shapes are unchanged.
        self.projection = nn.Linear(n_filters, emb_size)
        self.transformer = TransformerEncoder(depth=depth, emb_size=emb_size)
        self.classification = ClassificationHead(emb_size=emb_size, n_classes=n_classes)

    def forward(self, x):
        """Return class logits of shape ``(b, n_classes)``.

        Args:
            x: Tensor of shape ``(b, channels, T)`` or ``(b, 1, channels, T)``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)  # add the single input feature map for Conv2d
        x = self.conv1(x)       # (b, 40, channels, T - 24)
        x = self.conv2(x)       # (b, 40, 1, T - 24)
        x = self.batchnorm1(x)
        x = F.elu(x)
        x = self.avgpool1(x)    # (b, 40, 1, W), W = (T - 24 - 75) // 15 + 1
        # F.dropout defaults to training=True; tie it to the module mode so
        # inference under model.eval() is deterministic.
        x = F.dropout(x, 0.5, training=self.training)
        x = rearrange(x, 'b e h w -> b (h w) e')  # (b, W, 40)
        x = self.projection(x)                    # (b, W, emb_size)
        x = self.transformer(x)
        x = self.classification(x)
        return x