# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# --------------------------------------------------------
import torch
import torch.nn as nn

from models.embedding_small import PatchEmbedding
from models.transformer import TransformerEncoderLayer


class EEGEncoder(nn.Module):
    """Transformer encoder over patched EEG, with learnable global tokens
    injected into the token sequence at an intermediate layer."""

    def __init__(self, args):
        super(EEGEncoder, self).__init__()
        self.patch_embedding = PatchEmbedding(d_model=args.feature_size)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=args.feature_size,
                nhead=args.num_heads,
                dim_feedforward=args.dim_feedforward,
            )
            for _ in range(args.num_layers)
        ])
        # Learnable global tokens, broadcast over the batch in forward().
        self.global_tokens = nn.Parameter(
            torch.randn(1, args.num_global_tokens, args.feature_size)
        )
        # 1-indexed layer after which the global tokens are prepended.
        self.global_token_layer = args.global_token_layer

    def forward(self, x_in):
        # x_in: [B, C, P, L] = [batch, EEG channels, patches per channel, samples per patch]
        B, C, P, L = x_in.shape
        # Let the embedding adapt to a variable channel count across datasets.
        if hasattr(self.patch_embedding, 'in_dim'):
            self.patch_embedding.in_dim = C

        # 1. Patch embedding: [B, C, P, L] -> [B, C, P, D]
        x = self.patch_embedding(x_in)
        # Flatten channel and patch axes into one token sequence: [B, C*P, D]
        x = x.reshape(B, -1, x.shape[-1])

        global_tokens = self.global_tokens.expand(B, -1, -1)  # [B, num_global, D]

        # 2. Transformer layers; prepend the global tokens after layer
        # `global_token_layer` so all subsequent layers attend to them.
        for i, encoder_layer in enumerate(self.encoder_layers):
            x = encoder_layer(x)
            if i + 1 == self.global_token_layer:
                x = torch.cat([global_tokens, x], dim=1)  # [B, num_global + C*P, D]

        return x
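

# ------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the model). The hyperparameter
# values below are illustrative assumptions, not the paper's settings,
# and it presumes the repo's PatchEmbedding accepts [B, C, P, L] raw EEG
# and returns [B, C, P, d_model] features.
# ------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        feature_size=200,       # d_model (assumed value)
        num_heads=8,
        dim_feedforward=800,
        num_layers=12,
        num_global_tokens=4,
        global_token_layer=6,   # prepend global tokens after layer 6
    )
    model = EEGEncoder(args)
    # Dummy batch: 2 recordings, 19 channels, 4 patches of 200 samples each.
    x = torch.randn(2, 19, 4, 200)
    out = model(x)
    # Expected: [2, num_global_tokens + 19 * 4, feature_size] = [2, 80, 200]
    print(out.shape)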