# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on the BEiT-v2, timm, DeiT, DINOv2, LaBraM and CBraMod codebases:
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """Embeds raw EEG patches as the sum of four components: a temporal
    convolutional embedding, a spectral (rFFT magnitude) embedding, a
    learned per-channel embedding, and a convolutional positional (time)
    encoding.

    Expects input of shape [batch, eeg_channels, patches, patch_size] with
    patch_size = 200 samples, and requires d_model = 512: the conv stack
    below emits 64 feature maps over 8 temporal steps per patch, and
    64 * 8 = 512 must equal d_model for the flattening in forward().
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Depthwise conv over the patch axis, used as a positional (time) encoding.
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5),
                      stride=(1, 1), padding=(0, 2), groups=d_model),
        )
        # Temporal feature extractor: with patch_size = 200, the strided first
        # conv yields 8 temporal steps, so each patch flattens to 64 * 8 = d_model.
        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, 49),
                      stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(8, 64),
            nn.GELU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 3),
                      stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 128),
            nn.GELU(),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(1, 3),
                      stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 64),
            nn.GELU(),
        )
        # Projects the 101 rFFT magnitude bins of a 200-sample patch to d_model.
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )
        # Learned embedding for up to 19 EEG channels (10-20 montage).
        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        # x: [batch, eeg_channels, patches, patch_size]
        bz, ch_num, patch_num, patch_size = x.shape

        # Temporal embedding: run all (channel, patch) pairs through the conv
        # stack, then flatten each patch's 64 x 8 feature map to d_model.
        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        # Spectral embedding: magnitude of the one-sided FFT of each patch
        # (patch_size // 2 + 1 = 101 bins for 200-sample patches).
        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, patch_size // 2 + 1)
        spectral_emb = self.spectral_proj(spectral)
        patch_emb = patch_emb + spectral_emb

        # Channel embedding: one-hot channel identities projected to d_model,
        # broadcast over the batch and patch axes. Indices are created on the
        # input's device so the module also runs on CPU.
        channel_ids = torch.arange(ch_num, device=x.device)
        one_hot = F.one_hot(channel_ids, num_classes=self.num_channels).float()
        channel_pos = self.channel_embedding(one_hot)        # [ch_num, d_model]
        channel_pos = channel_pos.unsqueeze(0).unsqueeze(2)  # [1, ch_num, 1, d_model]
        channel_pos = channel_pos.expand(bz, -1, patch_num, -1)
        patch_emb = patch_emb + channel_pos

        # Positional (time) encoding along the patch axis via the depthwise conv.
        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        time_embedding = time_embedding.permute(0, 2, 3, 1)
        patch_emb = patch_emb + time_embedding
        return patch_emb
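

# --------------------------------------------------------
# Minimal smoke test (a sketch, not part of the training pipeline).
# It assumes the shape contract hard-wired into the module:
# patch_size = 200 samples (so the rFFT yields 101 bins) and
# d_model = 512 (64 conv feature maps x 8 temporal steps per patch).
# The batch, channel, and patch counts below are illustrative.
# --------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    embed = PatchEmbedding(d_model=512)
    # [batch=2, eeg_channels=19, patches_per_channel=4, samples_per_patch=200]
    x = torch.randn(2, 19, 4, 200)
    out = embed(x)
    print(out.shape)  # expected: torch.Size([2, 19, 4, 512])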