# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# ---------------------------------------------------------
import torch
import torch.nn as nn

from models.embedding_small import PatchEmbedding
from models.transformer import TransformerEncoderLayer

class EEGEncoder(nn.Module):
    """Transformer encoder over per-channel EEG patch embeddings.

    Learnable global tokens are prepended to the token sequence after a
    configurable encoder layer, so the remaining layers can aggregate a
    recording-level representation.
    """

    def __init__(self, args):
        super().__init__()
        # Project raw EEG patches into d_model-dimensional embeddings.
        self.patch_embedding = PatchEmbedding(
            d_model=args.feature_size
        )

        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=args.feature_size,
                nhead=args.num_heads,
                dim_feedforward=args.dim_feedforward,
            ) for _ in range(args.num_layers)
        ])

        # Learnable global tokens, shared across the batch.
        self.global_tokens = nn.Parameter(
            torch.randn(1, args.num_global_tokens, args.feature_size)
        )
        self.global_token_layer = args.global_token_layer
        
    def forward(self, x_in):
        # x_in: [B, C, P, L] = [batch, channels, patches, samples per patch]
        B, C, P, L = x_in.shape
        # Some PatchEmbedding variants track the channel count at runtime.
        if hasattr(self.patch_embedding, 'in_dim'):
            self.patch_embedding.in_dim = C

        # 1. Patch embedding: [B, C, P, L] -> [B, C, P, D]
        x = self.patch_embedding(x_in)

        # 2. Flatten the channel and patch axes into one token sequence.
        x = x.reshape(B, -1, x.shape[-1])  # [B, C*P, D]

        global_tokens = self.global_tokens.expand(B, -1, -1)  # [B, num_global, D]

        # 3. Encode, prepending the global tokens after the
        #    `global_token_layer`-th layer so later layers attend over them.
        for i, encoder_layer in enumerate(self.encoder_layers):
            x = encoder_layer(x)
            if i + 1 == self.global_token_layer:
                x = torch.cat([global_tokens, x], dim=1)  # [B, num_global + C*P, D]

        return x
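

# --------------------------------------------------------
# Usage sketch: a minimal forward pass through EEGEncoder. All
# hyperparameter values below are illustrative assumptions (not the
# repo's actual config), and it is assumed that PatchEmbedding maps
# [B, C, P, L] to [B, C, P, d_model] as the shape comments above
# indicate.
# --------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        feature_size=200,      # embedding dim D (assumed)
        num_heads=8,           # attention heads (assumed)
        dim_feedforward=800,   # feed-forward width (assumed)
        num_layers=12,         # encoder depth (assumed)
        num_global_tokens=1,   # number of learnable global tokens (assumed)
        global_token_layer=6,  # prepend globals after this layer (assumed)
    )
    encoder = EEGEncoder(args)

    # Dummy batch: 2 recordings, 19 channels, 4 patches of 200 samples each.
    x = torch.randn(2, 19, 4, 200)
    out = encoder(x)
    print(out.shape)  # expected: [2, 1 + 19*4, 200] = [2, 77, 200]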