# --------------------------------------------------------
# EEG-DINO: Learning EEG Foundation Models via Hierarchical Self-Distillation
# Based on BEiT-v2, timm, DeiT, DINO v2, LaBraM and CBraMod code bases
# https://github.com/microsoft/unilm/tree/master/beitv2
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dinov2
# https://github.com/935963004/LaBraM
# https://github.com/wjq-learning/CBraMod
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """Embeds raw EEG patches by summing a temporal conv encoding, a spectral
    (rfft magnitude) projection, a learned channel embedding, and a residual
    depthwise temporal smoothing term."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Depthwise conv over the patch axis; applied residually in forward().
        self.time_encoding = nn.Sequential(
            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2),
                      groups=d_model),
        )

        # Temporal encoder: a strided conv stack over each patch. For a
        # 200-sample patch the first conv (kernel 49, stride 25) yields 8 time
        # steps, so the flattened output per patch is 64 * 8 = 512 features;
        # the view() in forward() requires this product to equal d_model.
        self.proj_in = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, 49), stride=(1, 25), padding=(0, 24)),
            nn.GroupNorm(8, 64),
            nn.GELU(),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 128),
            nn.GELU(),

            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
            nn.GroupNorm(8, 64),
            nn.GELU(),
        )
        # Spectral projection: maps the rfft magnitude spectrum
        # (patch_size // 2 + 1 = 101 bins for 200-sample patches) to d_model.
        self.spectral_proj = nn.Sequential(
            nn.Linear(101, d_model),
            nn.Dropout(0.1),
        )

        # Learned channel embedding: one-hot channel indices projected to
        # d_model (19 channels, e.g. a standard 10-20 montage).
        self.num_channels = 19
        self.channel_embedding = nn.Linear(self.num_channels, d_model)

    def forward(self, x):
        # x: [batch, channels, patches, patch_size]
        bz, ch_num, patch_num, patch_size = x.shape
        channel_in = torch.arange(self.num_channels, device=x.device)

        # Temporal branch: fold channels and patches together so the conv
        # stack encodes every patch independently.
        x = x.contiguous().view(bz, 1, ch_num * patch_num, patch_size)
        patch_emb = self.proj_in(x)
        patch_emb = patch_emb.permute(0, 2, 1, 3).contiguous().view(bz, ch_num, patch_num, self.d_model)

        # Spectral branch: magnitude spectrum of each patch (101 bins for a
        # 200-sample patch), projected to d_model.
        x = x.contiguous().view(bz * ch_num * patch_num, patch_size)
        spectral = torch.fft.rfft(x, dim=-1, norm='forward')
        spectral = torch.abs(spectral).contiguous().view(bz, ch_num, patch_num, 101)
        spectral_emb = self.spectral_proj(spectral)

        patch_emb = patch_emb + spectral_emb

        # Channel embedding: one-hot each channel index, project to d_model,
        # and broadcast over the batch and patch dimensions.
        group_one_hot = F.one_hot(channel_in[:ch_num], num_classes=self.num_channels).float()
        group_emb = self.channel_embedding(group_one_hot)  # [ch_num, d_model]
        channel_pos = group_emb.unsqueeze(0).unsqueeze(2).expand(bz, -1, patch_num, -1)  # [bz, ch_num, patch_num, d_model]

        patch_emb = patch_emb + channel_pos

        # Residual depthwise temporal smoothing over the patch axis.
        time_embedding = self.time_encoding(patch_emb.permute(0, 3, 1, 2))
        time_embedding = time_embedding.permute(0, 2, 3, 1)

        patch_emb = patch_emb + time_embedding

        return patch_emb
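

# Usage sketch (assumptions, not from the source repos): 200-sample patches
# (e.g. 1 s at 200 Hz), so rfft yields 101 bins, and d_model = 512 so the
# flattened conv output (64 channels x 8 time steps) matches d_model.
if __name__ == "__main__":
    embed = PatchEmbedding(d_model=512)
    dummy = torch.randn(2, 19, 30, 200)  # [batch, channels, patches, samples]
    out = embed(dummy)
    print(out.shape)  # torch.Size([2, 19, 30, 512])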