# EEG-DINO / models/eeg_encoder.py
# Uploaded by eegdino, commit 11cc6a7 (verified).
# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
import torch
import torch.nn as nn
from monai.networks.nets.swin_unetr import *
import torch.nn.functional as F
from models.embedding_small import PatchEmbedding
from models.transformer import TransformerEncoderLayer
class EEGEncoder(nn.Module):
    """Transformer encoder over patch-embedded EEG, with learnable global tokens.

    The input is turned into patch embeddings by ``PatchEmbedding``. The
    channel and patch axes are flattened into a single token axis, and the
    tokens are run through a stack of ``TransformerEncoderLayer`` modules.
    A set of learnable global tokens is prepended to the sequence right after
    encoder layer number ``args.global_token_layer`` (1-based). Only the layers
    after that point process them.

    Args:
        args: Configuration object. The attributes read here are
            ``feature_size`` (model width D), ``num_heads``,
            ``dim_feedforward``, ``num_layers``, ``num_global_tokens`` and
            ``global_token_layer``.
    """

    def __init__(self, args):
        super().__init__()
        import warnings  # local import so the file's import block is untouched

        # NOTE: construction order is kept as-is so RNG-dependent
        # initialisation is reproducible with earlier versions.
        self.patch_embedding = PatchEmbedding(
            d_model=args.feature_size
        )
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=args.feature_size,
                nhead=args.num_heads,
                dim_feedforward=args.dim_feedforward,
            )
            for _ in range(args.num_layers)
        ])
        # Learnable global tokens, drawn from a standard normal.
        # forward() broadcasts them over the batch.
        self.global_tokens = nn.Parameter(
            torch.randn(1, args.num_global_tokens, args.feature_size)
        )
        self.global_token_layer = args.global_token_layer

        # forward() inserts the global tokens only when
        # 1 <= global_token_layer <= num_layers. A value of 0 leaves them out,
        # which may be intentional. Any other value silently drops them and so
        # changes the output layout, so make that visible instead of failing
        # quietly. `in range(...)` is used so non-int values (e.g. None)
        # compare without raising, exactly as the `==` test in forward() does.
        if self.global_token_layer not in range(args.num_layers + 1):
            warnings.warn(
                f"global_token_layer={self.global_token_layer!r} is outside "
                f"[0, {args.num_layers}]; the global tokens will never be "
                "inserted in forward().",
                stacklevel=2,
            )

    def forward(self, x_in):
        """Encode a batch of EEG patches into a token sequence.

        Args:
            x_in: 4-D tensor ``[B, C, P, L]``. The shape unpacking below raises
                ``ValueError`` for any other rank. The axes are presumably
                batch, channels, patches per channel and samples per patch;
                TODO confirm against the data pipeline.

        Returns:
            ``[B, num_global + C*P, D]`` when the global tokens are inserted,
            otherwise ``[B, C*P, D]``. Both shapes assume the patch embedding
            returns ``[B, C, P, D]``.
        """
        # Only the channel count is needed. Unpacking also enforces a 4-D input.
        _, num_channels, _, _ = x_in.shape
        # If the patch embedding exposes `in_dim`, sync it to this batch's
        # channel count. This is presumably to support a varying number of
        # channels; confirm against PatchEmbedding.
        if hasattr(self.patch_embedding, 'in_dim'):
            self.patch_embedding.in_dim = num_channels

        # 1. Patch embedding.
        x = self.patch_embedding(x_in)  # [B, C, P, D]
        batch = x.shape[0]
        x = x.reshape(batch, -1, x.shape[-1])  # [B, C*P, D]
        global_tokens = self.global_tokens.expand(batch, -1, -1)  # [B, num_global, D]

        # 2. Transformer layers. Global tokens join after the configured layer.
        for layer_idx, encoder_layer in enumerate(self.encoder_layers, start=1):
            x = encoder_layer(x)
            if layer_idx == self.global_token_layer:
                x = torch.cat([global_tokens, x], dim=1)  # [B, num_global+C*P, D]
        return x