import torch
import torch.nn as nn
from monai.networks.nets.swin_unetr import *
import torch.nn.functional as F

from models.embedding_small import PatchEmbedding
from models.transformer import TransformerEncoderLayer


class EEGEncoder(nn.Module):
    """Transformer encoder over EEG patch embeddings, with learnable global
    tokens prepended to the token sequence after a configurable layer."""

    def __init__(self, args):
        super().__init__()
        # Project raw EEG patches into feature_size-dimensional embeddings.
        self.patch_embedding = PatchEmbedding(d_model=args.feature_size)

        # Stack of transformer encoder layers.
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=args.feature_size,
                nhead=args.num_heads,
                dim_feedforward=args.dim_feedforward,
            )
            for _ in range(args.num_layers)
        ])

        # Learnable global tokens, shared across the batch and inserted into
        # the sequence after the `global_token_layer`-th encoder layer.
        self.global_tokens = nn.Parameter(
            torch.randn(1, args.num_global_tokens, args.feature_size)
        )
        self.global_token_layer = args.global_token_layer

    def forward(self, x_in):
        # x_in: (batch, channels, patches, patch_length); only the channel
        # count is used directly here.
        B, C, P, L = x_in.shape
        if hasattr(self.patch_embedding, 'in_dim'):
            # Let the patch embedding adapt to the number of input channels.
            self.patch_embedding.in_dim = C

        x = self.patch_embedding(x_in)
        b = x.shape[0]

        # Flatten all patch embeddings into one token sequence per sample.
        x = x.reshape(b, -1, x.shape[-1])

        # Broadcast the learnable global tokens across the batch.
        global_tokens = self.global_tokens.expand(b, -1, -1)

        for i, encoder_layer in enumerate(self.encoder_layers):
            x = encoder_layer(x)
            if i + 1 == self.global_token_layer:
                # Prepend the global tokens so later layers can attend to them.
                x = torch.cat([global_tokens, x], dim=1)

        return x
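
# --- Usage sketch (not part of the original module) ---
# A minimal smoke test, assuming PatchEmbedding(d_model=...) accepts a
# (batch, channels, patches, patch_length) tensor and returns embeddings of
# size d_model; the exact input shape it expects is an assumption here. The
# argument names mirror the fields read above (feature_size, num_heads,
# dim_feedforward, num_layers, num_global_tokens, global_token_layer).
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        feature_size=128,
        num_heads=8,
        dim_feedforward=256,
        num_layers=4,
        num_global_tokens=4,
        global_token_layer=2,
    )
    model = EEGEncoder(args)

    # Hypothetical input: 2 samples, 19 EEG channels, 16 patches of 200 samples.
    x = torch.randn(2, 19, 16, 200)
    out = model(x)  # output length depends on PatchEmbedding's patch resolution
    print(out.shape)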