# manas-1 / modelclass.py
# (uploaded by neurodx-labs, commit 48d00fd)
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from tqdm import tqdm
from mae import MAE
# Use cm positions from mne.get_montage() - these are only included for reference.
# 10-20 system electrode coordinates as (x, y, z) in centimetres, keyed by
# channel name. NOTE(review): not referenced anywhere in this file; presumably
# callers build the `pos` argument to MANAS1.forward() from these — verify.
POSITIONS = {
"Fp1": (-3.09, 11.46, 2.79),
"Fp2": (2.84, 11.53, 2.77),
"F3": (-5.18, 8.67, 7.87),
"F4": (5.03, 8.74, 7.73),
"F7": (-7.19, 7.31, 2.58),
"F8": (7.14, 7.45, 2.51),
"T3": (-8.60, 1.49, 3.12),
"T4": (8.33, 1.53, 3.10),
"C3": (-6.71, 2.34, 10.45),
"C4": (6.53, 2.36, 10.37),
"T5": (-8.77, 1.29, -0.77),
"T6": (8.37, 1.17, -0.77),
"P3": (-5.50, -4.42, 9.99),
"P4": (5.36, -4.43, 10.05),
"O1": (-3.16, -8.06, 5.48),
"O2": (2.77, -8.05, 5.47),
"Fz": (-0.12, 9.33, 10.26),
"Cz": (-0.14, 2.76, 14.02),
"Pz": (-0.17, -4.52, 12.67),
"A2": (8.39, 0.20, -2.69),
}
class MANAS1(nn.Module):
    """EEG representation model built on a pretrained MAE encoder.

    Loads a masked-autoencoder checkpoint and reuses its patch embedding,
    positional encoding, and transformer encoder to turn raw multi-channel
    EEG signals into latent token representations.

    NOTE(review): ``num_classes`` and ``flat_dim`` belong to a classification
    head that is not implemented yet; neither affects the forward pass.
    """

    def __init__(self, checkpoint_path, num_classes=2, flat_dim=512):
        """Build the model from a pretrained MAE checkpoint.

        Args:
            checkpoint_path: Path to a ``torch.save`` file containing a
                ``"model_state_dict"`` entry for the MAE.
            num_classes: Output classes for the (future) classification
                head. Currently unused.
            flat_dim: Flattened feature size for the (future) head. Stored
                on the instance but currently unused.
        """
        super().__init__()
        print(f"Loading checkpoint from {checkpoint_path}...")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints; consider weights_only=True if the checkpoint
        # contains only tensors.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        self.mae = MAE(
            fs=200,
            embed_dim=512,
            encoder_depth=12,
            encoder_heads=8,
            decoder_depth=4,
            decoder_heads=8,
            mask_ratio=0.55,
        )
        self.mae.load_state_dict(ckpt["model_state_dict"])
        # Expose the pretrained components used by forward().
        self.patch_embed = self.mae.patch_embed
        self.pos_enc = self.mae.pos_enc
        self.encoder = self.mae.encoder
        self.patch_size = self.mae.patch_size
        self.step = self.mae.step
        self.flat_dim = flat_dim
        # TODO: add the classification head for fine-tuning, e.g.
        # Flatten -> RMSNorm(flat_dim) -> Dropout(0.1) -> Linear(flat_dim, num_classes).

    def prepare_coords(self, xyz, num_patches):
        """Build per-token (x, y, z, t) coordinates for positional encoding.

        Args:
            xyz: Tensor of shape (B, C, 3) with one 3-D electrode position
                per channel.
            num_patches: Number of temporal patches per channel.

        Returns:
            Tensor of shape (B, C * num_patches, 4): each channel's spatial
            position repeated over time, with the patch index appended as a
            fourth (time) coordinate. Token order is channel-major.
        """
        B, C, _ = xyz.shape
        device = xyz.device
        # The patch index 0..num_patches-1 serves as the time coordinate.
        time_idx = torch.arange(num_patches, device=device).float()
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)          # (B, C, P, 3)
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)  # (B, C, P, 1)
        return torch.cat([spat, time], dim=-1).flatten(1, 2)             # (B, C*P, 4)

    def forward(self, x, pos):
        """Encode raw EEG into latent token representations.

        Args:
            x: Signal tensor whose last dimension is time; it is unfolded
                into patches of ``self.patch_size`` with stride ``self.step``.
                Presumably (B, C, T) — confirm against caller.
            pos: Electrode coordinates of shape (B, C, 3).

        Returns:
            Encoder latents for all tokens (no classification head yet).
        """
        # Last dim: T -> (P, patch_size) windows.
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]
        # Embed each patch, then merge channel and patch dims into one token axis.
        tokens = self.patch_embed.linear(patches).flatten(1, 2)
        coords = self.prepare_coords(pos, num_patches)
        pe = self.pos_enc(coords)
        x_enc = tokens + pe
        latents, _ = self.encoder(x_enc)
        # TODO: pass latents through the classification head once it exists.
        return latents