| | import torch |
| | import torch.nn as nn |
| | import numpy as np |
| | from functools import partial |
| | from tqdm import tqdm |
| |
|
| | from mae import MAE |
| |
|
| | |
# Per-channel 3-D electrode coordinates, keyed by channel label. The labels
# look like 10-20 system names (with the older T3/T4/T5/T6 spelling).
# NOTE(review): the coordinate frame and units are not stated in this file;
# confirm them against the montage these values were taken from.
POSITIONS = dict(
    Fp1=(-3.09, 11.46, 2.79),
    Fp2=(2.84, 11.53, 2.77),
    F3=(-5.18, 8.67, 7.87),
    F4=(5.03, 8.74, 7.73),
    F7=(-7.19, 7.31, 2.58),
    F8=(7.14, 7.45, 2.51),
    T3=(-8.60, 1.49, 3.12),
    T4=(8.33, 1.53, 3.10),
    C3=(-6.71, 2.34, 10.45),
    C4=(6.53, 2.36, 10.37),
    T5=(-8.77, 1.29, -0.77),
    T6=(8.37, 1.17, -0.77),
    P3=(-5.50, -4.42, 9.99),
    P4=(5.36, -4.43, 10.05),
    O1=(-3.16, -8.06, 5.48),
    O2=(2.77, -8.05, 5.47),
    Fz=(-0.12, 9.33, 10.26),
    Cz=(-0.14, 2.76, 14.02),
    Pz=(-0.17, -4.52, 12.67),
    A2=(8.39, 0.20, -2.69),
)
| |
|
| |
|
class MANAS1(nn.Module):
    """Encoder-only wrapper around a pretrained ``MAE`` checkpoint.

    The constructor rebuilds an ``MAE`` and loads its weights from
    ``checkpoint_path``. It then re-exposes the MAE's patch embedding,
    positional encoder and encoder as direct attributes. ``forward`` embeds
    *every* patch (no masking is applied here) and returns the encoder output.

    Args:
        checkpoint_path: File readable by ``torch.load``. The loaded object
            must be a mapping with a ``"model_state_dict"`` entry holding the
            MAE weights.
        num_classes: Accepted, but not used by any code in this class.
        flat_dim: Stored as ``self.flat_dim``; not otherwise used here.
        mae_kwargs: Optional mapping of ``MAE`` constructor arguments that
            override ``DEFAULT_MAE_KWARGS``. They must describe the
            architecture the checkpoint was saved from; otherwise
            ``load_state_dict`` will typically raise. ``None`` keeps the
            defaults.

    NOTE(review): ``self.mae`` stays registered as a submodule. Any MAE parts
    that ``forward`` never calls (presumably the decoder configured by
    ``decoder_depth``/``decoder_heads``) therefore still appear in
    ``parameters()`` and ``state_dict()``. The re-exposed parts are also
    registered a second time under their own names, assuming they are
    ``nn.Module``s. This is kept unchanged because removing it would change
    the ``state_dict`` keys.
    """

    # Default ``MAE`` constructor arguments; override per instance via
    # ``mae_kwargs``. Never mutated in place.
    DEFAULT_MAE_KWARGS = {
        "fs": 200,
        "embed_dim": 512,
        "encoder_depth": 12,
        "encoder_heads": 8,
        "decoder_depth": 4,
        "decoder_heads": 8,
        "mask_ratio": 0.55,
    }

    def __init__(self, checkpoint_path, num_classes=2, flat_dim=512, mae_kwargs=None):
        super().__init__()

        print(f"Loading checkpoint from {checkpoint_path}...")
        # Security: depending on the installed torch version's ``weights_only``
        # default, torch.load may fully unpickle the file (arbitrary code
        # execution). Only point this at trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        # Merge into a fresh dict so that neither the class-level defaults nor
        # the caller's mapping is mutated.
        config = {**self.DEFAULT_MAE_KWARGS, **(mae_kwargs or {})}
        self.mae = MAE(**config)
        self.mae.load_state_dict(ckpt["model_state_dict"])

        # Re-expose the pieces forward() needs as direct attributes.
        self.patch_embed = self.mae.patch_embed
        self.pos_enc = self.mae.pos_enc
        self.encoder = self.mae.encoder
        self.patch_size = self.mae.patch_size
        self.step = self.mae.step

        self.flat_dim = flat_dim

    def prepare_coords(self, xyz, num_patches):
        """Pair every channel position with every patch index.

        Args:
            xyz: Tensor of shape ``(B, C, D)`` with one position per channel.
                ``D`` is presumably 3 (x, y, z).
            num_patches: Number of temporal patches per channel.

        Returns:
            Tensor of shape ``(B, C * num_patches, D + 1)``. Token
            ``c * num_patches + n`` holds channel ``c``'s position followed by
            the float32 time index ``n``. This is the same channel-major order
            that ``forward`` produces with ``flatten(1, 2)``.
        """
        B, C, _ = xyz.shape
        device = xyz.device
        time_idx = torch.arange(num_patches, device=device).float()
        # Broadcast positions across patches, and patch indices across (B, C).
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)
        return torch.cat([spat, time], dim=-1).flatten(1, 2)

    def forward(self, x, pos):
        """Encode all patches of ``x``; no masking is applied.

        Args:
            x: Signal tensor, expected ``(B, C, T)``. It is cut along the last
                axis into windows of ``self.patch_size`` samples taken every
                ``self.step`` samples.
            pos: Per-channel positions, ``(B, C, D)``; see ``prepare_coords``.

        Returns:
            The first element returned by ``self.encoder``. These are
            presumably ``(B, C * num_patches, embed_dim)`` latents.
        """
        # (B, C, T) -> (B, C, num_patches, patch_size)
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]

        # Embed each patch, then merge the channel and patch axes into one token axis.
        tokens = self.patch_embed.linear(patches).flatten(1, 2)

        coords = self.prepare_coords(pos, num_patches)
        pe = self.pos_enc(coords)

        x_enc = tokens + pe
        # The encoder's second output is not needed here.
        latents, _ = self.encoder(x_enc)

        return latents