# EEG-DINO / models / embedding_medium.py
# (HuggingFace upload by eegdino, commit 11cc6a7, verified)
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
class PatchEmbedding(nn.Module):
    """Embed raw EEG patches into d_model-dim tokens.

    Input is a tensor of shape (batch, ch_num, patch_num, patch_size).
    Three additive components are combined:
      1. a temporal conv encoder over each patch (``proj_in``),
      2. a projection of the rfft magnitude spectrum (``spectral_proj``),
      3. a learned per-channel embedding (``channel_embedding``),
    followed by a depthwise conv over the patch axis (``time_encoding``)
    added as a residual. Output shape: (batch, ch_num, patch_num, d_model).

    NOTE(review): the spectral branch hard-codes 101 rfft bins, which
    implies patch_size == 200; the conv stack implies
    d_model == 64 * ((patch_size - 1) // 25 + 1) (512 for patch_size 200);
    and ch_num must be <= 19 — confirm against the callers.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Depthwise (groups=d_model) smoothing across neighboring patches.
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5),
                      stride=(1, 1), padding=(0, 2), groups=d_model),
        )
        # Temporal encoder: stride-25 downsampling conv, then two refining convs.
        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(8, 64),
            nn.GELU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 128),
            nn.GELU(),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 64),
            nn.GELU(),
        )
        # Projects the 101 rfft magnitude bins (patch_size 200 assumed) to d_model.
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )
        # Number of distinct EEG electrode channels the embedding supports.
        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        """Encode EEG patches.

        Args:
            x: float tensor of shape (batch, ch_num, patch_num, patch_size);
               ch_num must not exceed ``self.num_channels``.

        Returns:
            Tensor of shape (batch, ch_num, patch_num, d_model).
        """
        bz, ch_num, patch_num, patch_size = x.shape

        # Temporal branch: treat all (channel, patch) rows as one conv axis.
        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        # Spectral branch: magnitude of the one-sided FFT of each patch.
        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101)
        patch_emb = patch_emb + self.spectral_proj(spectral)

        # Channel embedding: one-hot channel identity projected to d_model.
        # BUG FIX: was torch.arange(self.num_channels + 1).cuda(), which
        # crashed on CPU-only hosts (and could place the tensor on the wrong
        # GPU); allocate on the input's device instead. Dropping the "+1"
        # also removes a latent index-19 overflow in F.one_hot(num_classes=19).
        channel_ids = torch.arange(self.num_channels, device=x.device)[:ch_num]
        channel_one_hot = F.one_hot(channel_ids, num_classes=self.num_channels).float()
        channel_pos = self.channel_embedding(channel_one_hot)  # [ch_num, d_model]
        channel_pos = channel_pos.unsqueeze(0).unsqueeze(2).expand(bz, -1, patch_num, -1)
        patch_emb = patch_emb + channel_pos

        # Residual depthwise smoothing over the patch (time) axis.
        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        patch_emb = patch_emb + time_embedding.permute(0, 2, 3, 1)
        return patch_emb