import torch
import torch.nn as nn

from mae import MAE

# Electrode positions in cm (e.g. derived from an MNE standard montage via
# montage.get_positions()); listed here for reference only.
POSITIONS = {
    "Fp1": (-3.09, 11.46, 2.79),
    "Fp2": (2.84, 11.53, 2.77),
    "F3": (-5.18, 8.67, 7.87),
    "F4": (5.03, 8.74, 7.73),
    "F7": (-7.19, 7.31, 2.58),
    "F8": (7.14, 7.45, 2.51),
    "T3": (-8.60, 1.49, 3.12),
    "T4": (8.33, 1.53, 3.10),
    "C3": (-6.71, 2.34, 10.45),
    "C4": (6.53, 2.36, 10.37),
    "T5": (-8.77, 1.29, -0.77),
    "T6": (8.37, 1.17, -0.77),
    "P3": (-5.50, -4.42, 9.99),
    "P4": (5.36, -4.43, 10.05),
    "O1": (-3.16, -8.06, 5.48),
    "O2": (2.77, -8.05, 5.47),
    "Fz": (-0.12, 9.33, 10.26),
    "Cz": (-0.14, 2.76, 14.02),
    "Pz": (-0.17, -4.52, 12.67),
    "A2": (8.39, 0.20, -2.69),
}


class MANAS1(nn.Module):
    def __init__(self, checkpoint_path, num_classes=2, flat_dim=512):
        super().__init__()

        print(f"Loading checkpoint from {checkpoint_path}...")
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        self.mae = MAE(fs=200, embed_dim=512, encoder_depth=12, encoder_heads=8, decoder_depth=4, decoder_heads=8, mask_ratio=0.55)
        self.mae.load_state_dict(ckpt["model_state_dict"])

        self.patch_embed = self.mae.patch_embed
        self.pos_enc = self.mae.pos_enc
        self.encoder = self.mae.encoder
        self.patch_size = self.mae.patch_size
        self.step = self.mae.step

        self.flat_dim = flat_dim

        # Classification head (currently disabled; forward returns raw encoder
        # latents). Note: nn.Flatten over (B, C * P, embed_dim) produces
        # C * P * embed_dim features, so flat_dim must match that product
        # before this head can be enabled as written.
        # self.final_layer = nn.Sequential(
        #     nn.Flatten(),
        #     nn.RMSNorm(self.flat_dim),  # Tutorial uses RMSNorm
        #     nn.Dropout(0.1),
        #     nn.Linear(self.flat_dim, num_classes),
        # )

    def prepare_coords(self, xyz, num_patches):
        # xyz: (B, C, 3) electrode positions. Each channel's (x, y, z) is
        # repeated across time patches, and the patch index is appended as a
        # fourth "time" coordinate. Returns (B, C * num_patches, 4).
        B, C, _ = xyz.shape
        device = xyz.device
        time_idx = torch.arange(num_patches, device=device).float()
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)
        return torch.cat([spat, time], dim=-1).flatten(1, 2)

    def forward(self, x, pos):
        # x: (B, C, T) raw signal; pos: (B, C, 3) electrode coordinates.
        patches = x.unfold(-1, self.patch_size, self.step)  # (B, C, P, patch_size)
        num_patches = patches.shape[2]

        tokens = self.patch_embed.linear(patches).flatten(1, 2)  # (B, C * P, embed_dim)

        coords = self.prepare_coords(pos, num_patches)
        pe = self.pos_enc(coords)

        x_enc = tokens + pe
        latents, _ = self.encoder(x_enc)

        # TODO: apply a classification head here; for now return encoder latents
        return latents
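The patchify and coordinate-preparation steps above can be sketched standalone. This is a minimal shape walk-through, not the real model: the batch size, channel count, signal length, `patch_size=200`, and `step=100` are all assumed values for illustration (the real patch size and step come from the pretrained MAE).

```python
import torch

# Hypothetical shapes mirroring MANAS1.forward: B=2 batches, C=20 channels,
# T=2000 samples (10 s at an assumed fs=200). patch_size and step are guesses.
B, C, T = 2, 20, 2000
patch_size, step = 200, 100

x = torch.randn(B, C, T)
patches = x.unfold(-1, patch_size, step)   # (B, C, num_patches, patch_size)
num_patches = patches.shape[2]             # (T - patch_size) // step + 1 = 19

# prepare_coords: repeat each channel's (x, y, z) across time patches and
# append the patch index as a fourth coordinate.
xyz = torch.randn(B, C, 3)
time_idx = torch.arange(num_patches).float()
spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)          # (B, C, P, 3)
time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)  # (B, C, P, 1)
coords = torch.cat([spat, time], dim=-1).flatten(1, 2)           # (B, C * P, 4)

print(patches.shape, coords.shape)
```

With these assumed values, `patches` has shape `(2, 20, 19, 200)` and `coords` has shape `(2, 380, 4)`, matching the `(B, C * P, 4)` layout that `pos_enc` consumes.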