# Source: eeg-cognitive-load/src/model.py
# Uploaded by dodo-2100 using huggingface_hub (commit 2afe0cd, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
class PatchEmbedding(nn.Module):
    """Turn a 4-D input into a token sequence with a class token and positions.

    The input is cut into non-overlapping ``(1, patch_size)`` windows along its
    last axis by a strided convolution, each window is projected to
    ``emb_size`` features, a learnable class token is prepended, and a
    learnable positional embedding is added.

    Args:
        emb_size: Width of every output token.
        patch_size: Length (and stride) of each patch along the last axis.
        in_channels: Number of channels of the 4-D input.

    Note:
        The positional table holds 2000 rows, so the sequence (class token
        included) must not be longer than 2000 tokens.
    """

    def __init__(self, emb_size=40, patch_size=25, in_channels=1):
        super().__init__()
        patch_window = (1, patch_size)
        patch_conv = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_window,
            stride=patch_window,
        )
        # (b, e, h, w) feature map -> (b, h * w, e) token sequence.
        to_tokens = Rearrange('b e (h) (w) -> b (h w) e')
        self.projection = nn.Sequential(patch_conv, to_tokens)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn(2000, emb_size))

    def forward(self, x):
        """Return tokens of shape ``(b, 1 + num_patches, emb_size)``."""
        batch_size, _, _, _ = x.shape
        tokens = self.projection(x)
        # One copy of the class token per sample, placed in front.
        cls_per_sample = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_per_sample, tokens], dim=1)
        seq_len = tokens.shape[1]
        tokens += self.positions[:seq_len, :]
        return tokens
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        emb_size: Width of every token; must be divisible by ``num_heads``.
        num_heads: Number of heads the embedding is split into.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t):
        """Reshape ``(b, n, h * d)`` into ``(b, h, n, d)``."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, -1).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(b, n, emb_size)``.
            mask: Optional mask broadcastable to ``(b, num_heads, n, n)``
                (query index before key index). Truthy entries may be
                attended to; falsy entries are excluded from the softmax.
                ``None`` disables masking.

        Returns:
            Tensor of shape ``(b, n, emb_size)``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        # Raw attention scores for every (query, key) pair: (b, h, n, n).
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            keep = mask.to(device=energy.device, dtype=torch.bool)
            # Most negative finite value -> zero weight after softmax, and a
            # fully masked row degrades to uniform weights instead of NaN.
            energy = energy.masked_fill(~keep, torch.finfo(energy.dtype).min)

        # NOTE: scores are scaled by sqrt(emb_size), not the conventional
        # sqrt(head_dim). Kept as-is because changing it would alter the
        # outputs of weights trained with this module.
        scaling = self.emb_size ** 0.5
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # Merge heads back: (b, h, n, d) -> (b, n, h * d).
        out = out.transpose(1, 2).flatten(2)
        return self.projection(out)
class ResidualAdd(nn.Module):
    """Wrap a callable with a skip connection: ``fn(x) + x``.

    Args:
        fn: Module (or callable) whose output has a shape that can be added
            to its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x`` without modifying any tensor in place."""
        # The add is deliberately out-of-place. An in-place ``+=`` on the
        # output of ``fn`` corrupts the caller's tensor whenever ``fn`` returns
        # its input unchanged (e.g. Identity, Dropout in eval mode), and breaks
        # autograd for functions whose backward pass reads their own output.
        return self.fn(x, **kwargs) + x
class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network: expand, GELU, dropout, project back.

    Args:
        emb_size: Input and output feature width.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, dropout):
        hidden_size = expansion * emb_size
        layers = (
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, emb_size),
        )
        super().__init__(*layers)
class TransformerEncoderBlock(nn.Sequential):
    """One encoder layer: attention and feed-forward, each in a residual branch.

    Both branches apply LayerNorm first, then the sub-layer, then dropout, and
    are wrapped in ``ResidualAdd``.

    Args:
        emb_size: Token embedding width.
        num_heads: Number of attention heads.
        drop_p: Dropout probability for the attention weights and for the
            output of both branches.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout probability inside the feed-forward block.
    """

    def __init__(self, emb_size, num_heads=4, drop_p=0.5, forward_expansion=4, forward_drop_p=0.5):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, dropout=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )
class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` encoder blocks applied one after another.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers.
        emb_size: Token embedding width shared by all layers.
    """

    def __init__(self, depth, emb_size):
        blocks = (TransformerEncoderBlock(emb_size) for _ in range(depth))
        super().__init__(*blocks)
class ClassificationHead(nn.Sequential):
    """Mean-pool the token sequence, normalise it, and map it to class logits.

    Args:
        emb_size: Token embedding width.
        n_classes: Number of output logits.
    """

    def __init__(self, emb_size, n_classes):
        super().__init__()
        # Average over the token axis: (b, n, e) -> (b, e).
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalise = nn.LayerNorm(emb_size)
        to_logits = nn.Linear(emb_size, n_classes)
        self.clshead = nn.Sequential(pool_tokens, normalise, to_logits)

    def forward(self, x):
        """Return logits of shape ``(b, n_classes)`` for ``x`` of shape ``(b, n, e)``."""
        logits = self.clshead(x)
        return logits
class EEGConformer(nn.Module):
    """Convolutional front-end followed by a transformer encoder and classifier.

    Args:
        n_classes: Number of output logits.
        channels: Number of rows (EEG channels) in the input; the second
            convolution spans all of them at once.
        time_points: Currently unused; accepted so existing callers keep
            working.
    """

    def __init__(self, n_classes=3, channels=64, time_points=1000):
        super().__init__()
        # Convolution along the last (time) axis only.
        self.conv1 = nn.Conv2d(1, 40, (1, 25), (1, 1))
        # Convolution across the whole channel axis, collapsing it to size 1
        # when the input has exactly ``channels`` rows.
        self.conv2 = nn.Conv2d(40, 40, (channels, 1), (1, 1))
        self.batchnorm1 = nn.BatchNorm2d(40)
        self.avgpool1 = nn.AvgPool2d((1, 75), (1, 15))
        # Not used in forward(); kept so the attribute remains available.
        self.flatten = nn.Flatten()
        # in_features is inferred on the first forward pass; the rearrange in
        # forward() leaves the 40 convolution feature maps as the last axis.
        self.projection = nn.LazyLinear(40)
        self.transformer = TransformerEncoder(depth=4, emb_size=40)
        self.classification = ClassificationHead(emb_size=40, n_classes=n_classes)

    def forward(self, x):
        """Classify a batch of recordings.

        Args:
            x: Tensor of shape ``(b, channels, time)`` or
                ``(b, 1, channels, time)``.

        Returns:
            Logits of shape ``(b, n_classes)``.
        """
        if x.dim() == 3:
            # Add the singleton feature axis expected by Conv2d.
            x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.batchnorm1(x)
        x = F.elu(x)
        x = self.avgpool1(x)
        # F.dropout defaults to training=True; tie it to the module mode so
        # that eval() yields deterministic predictions.
        x = F.dropout(x, 0.5, training=self.training)
        # Feature maps become token embeddings: (b, e, h, w) -> (b, h * w, e).
        x = rearrange(x, 'b e h w -> b (h w) e')
        x = self.projection(x)
        x = self.transformer(x)
        x = self.classification(x)
        return x