neurodx-labs committed on
Commit
48d00fd
·
verified ·
1 Parent(s): b5001b2

Upload 3 files

Browse files
Files changed (3) hide show
  1. mae.py +357 -0
  2. manas1.pt +3 -0
  3. modelclass.py +81 -0
mae.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class PatchEmbed(nn.Module):
    """Cut a signal into (possibly overlapping) time windows and project each to a token.

    The last axis of the input is unfolded into windows of `patch_seconds * fs`
    samples, advancing by `patch - overlap` samples, and each window is mapped
    to `embed_dim` via a bias-free linear layer.
    """

    def __init__(self, fs: int = 200, patch_seconds: float = 1.0, overlap_seconds: float = 0.1, embed_dim: int = 512):
        super().__init__()

        # Window geometry, expressed in samples.
        patch_len = int(round(patch_seconds * fs))
        overlap_len = int(round(overlap_seconds * fs))

        self.patch_size = patch_len
        self.overlap_size = overlap_len
        # Hop between consecutive windows.
        self.step = patch_len - overlap_len

        # Raw window -> embedding projection (no bias).
        self.linear = nn.Linear(patch_len, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold the trailing (time) axis into windows, then embed each window."""
        windows = x.unfold(dimension=-1, size=self.patch_size, step=self.step)
        return self.linear(windows)
20
+
21
+
22
class PosEnc(nn.Module):
    """Positional encoding over 4-D coordinates: Fourier features plus a learned branch.

    `coords` is expected to carry 4 values per token (last axis); the output has
    `embed_dim` features in their place. The two branches are summed and
    LayerNorm-ed.
    """

    def __init__(self, n_freqs: int = 4, embed_dim: int = 512):
        super().__init__()

        base = torch.linspace(1.0, 10.0, n_freqs)
        # Every frequency combination over the 4 coordinate axes, stored as
        # (4, n_freqs**4) so that `coords @ freq_matrix` yields one phase per
        # combination.
        self.register_buffer("freq_matrix", torch.cartesian_prod(base, base, base, base).transpose(1, 0))

        # sin + cos per frequency combination.
        n_fourier = 2 * (n_freqs**4)

        self.fourier_linear = nn.Linear(n_fourier, embed_dim, bias=False)
        self.learned_linear = nn.Sequential(nn.Linear(4, embed_dim, bias=False), nn.GELU(), nn.LayerNorm(embed_dim))

        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, coords: torch.Tensor):
        """Map coords (..., 4) to embeddings (..., embed_dim)."""
        phases = coords @ self.freq_matrix

        fourier_emb = self.fourier_linear(torch.cat((phases.sin(), phases.cos()), dim=-1))
        learned_emb = self.learned_linear(coords)

        return self.final_norm(fourier_emb + learned_emb)
45
+
46
+
47
class TransformerBlock(nn.Module):
    """Pre-norm transformer block (self-attention + 4x-expansion FFN).

    `forward` returns both the block output and the raw FFN output, so a caller
    can collect per-layer FFN activations.
    """

    def __init__(self, embed_dim: int, heads: int, dropout: float = 0.0):
        super().__init__()

        assert embed_dim % heads == 0, "dim must be divisible by heads"

        self.pre_attn_norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads, dropout=dropout, batch_first=True)

        self.pre_ffn_norm = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(), nn.Linear(4 * embed_dim, embed_dim))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Attention sub-layer with residual connection.
        normed = self.pre_attn_norm(x)
        attended, _ = self.attn(normed, normed, normed)
        x = x + attended

        # FFN sub-layer; keep the raw FFN output to hand back to the caller.
        mlp_out = self.ffn(self.pre_ffn_norm(x))
        return x + mlp_out, mlp_out
71
+
72
+
73
class TransformerEncoderDecoder(nn.Module):
    """A stack of `depth` TransformerBlocks followed by a final LayerNorm.

    Returns the normalized output together with the list of each layer's raw
    FFN activations (one tensor per block, in layer order).
    """

    def __init__(self, embed_dim: int = 512, depth: int = 16, heads: int = 8):
        super().__init__()

        self.layers = nn.ModuleList([TransformerBlock(embed_dim, heads) for _ in range(depth)])
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        ffn_outputs: list[torch.Tensor] = []

        for block in self.layers:
            x, ffn_out = block(x)
            ffn_outputs.append(ffn_out)

        return self.final_norm(x), ffn_outputs
88
+
89
+
90
class MAEDecoder(nn.Module):
    """Lightweight MAE decoder.

    Fills every masked position with a shared learnable mask token, pastes the
    encoder's output into the visible positions, adds positional encoding, runs
    a shallow transformer, and regresses a raw signal patch for every token.
    """

    def __init__(self, embed_dim: int = 512, decoder_depth: int = 4, decoder_heads: int = 8, patch_size: int = 200):
        super().__init__()

        # 1. The Mask Token (the "gray tile"):
        # a learnable vector that stands in for every missing patch.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.mask_token, std=0.02)

        # 2. The decoder transformer — lighter (fewer layers) than the encoder.
        self.decoder = TransformerEncoderDecoder(embed_dim=embed_dim, depth=decoder_depth, heads=decoder_heads)

        # 3. The prediction head: embedding (embed_dim) -> raw patch (patch_size).
        self.predict = nn.Linear(embed_dim, patch_size, bias=True)

    def forward(self, x_visible: torch.Tensor, pos_enc: nn.Module, coords: torch.Tensor, mask: torch.Tensor):
        """Reconstruct every patch.

        x_visible: (B, N_vis, D) encoder output for visible tokens.
        pos_enc:   positional-encoding module mapping coords (B, N, 4) -> (B, N, D).
        coords:    (B, N_total, 4) coordinates for ALL tokens.
        mask:      (B, N_total) bool, True = visible; the True-count must be the
                   same for every batch row (generate_mask guarantees this).
        Returns (B, N_total, patch_size) predictions.
        """
        B, N_total, D = coords.shape[0], coords.shape[1], x_visible.shape[-1]

        # --- Step A: fill the canvas with mask tokens ---
        x_full = self.mask_token.expand(B, N_total, D).clone()

        # --- Step B: paste visible tokens ---
        # Boolean indexing visits rows in batch-major order, matching the
        # flattened order of x_visible, so one vectorized scatter replaces the
        # original per-sample Python loop. This relies on the per-row True
        # counts being equal (see generate_mask).
        x_full[mask] = x_visible.reshape(-1, D)

        # --- Step C: add positional encoding ---
        x_full = x_full + pos_enc(coords)

        # --- Step D: decode ---
        # Intermediate per-layer FFN outputs are ignored here.
        x_decoded, _ = self.decoder(x_full)

        # --- Step E: predict ---
        # (B, N_total, D) -> (B, N_total, patch_size)
        return self.predict(x_decoded)
138
+
139
+
140
def generate_mask(coords: torch.Tensor, mask_ratio: float = 0.55, spatial_radius: float = 3.0, temporal_radius: float = 3.0):
    """Build a boolean visibility mask (True = visible) via block masking.

    coords: (B, N, 4) with columns (x, y, z, t). For every batch row, random
    spatio-temporal blocks are hidden until at least `mask_ratio * N` tokens
    are masked, then the excess is randomly un-masked so exactly
    int(mask_ratio * N) tokens end up False in each row.
    """
    B, N, _ = coords.shape
    device = coords.device

    # Exact number of tokens to hide per batch row.
    target = int(mask_ratio * N)

    # Everything starts visible.
    mask = torch.ones(B, N, dtype=torch.bool, device=device)

    for b in range(B):
        xyz = coords[b, :, :3]
        t = coords[b, :, 3]

        # --- Phase 1: block masking ---
        # Grow random blocks until the masked count meets or exceeds the target.
        while (~mask[b]).sum() < target:
            # Random seed token.
            seed = torch.randint(0, N, (1,)).item()

            # Tokens within the spatial radius of the seed...
            near_space = torch.norm(xyz - xyz[seed], dim=1) <= spatial_radius
            # ...and within the temporal radius.
            near_time = torch.abs(t - t[seed]) <= temporal_radius

            # Hide the whole block.
            mask[b, near_space & near_time] = False

        # --- Phase 2: exact count enforcement ---
        # Phase 1 likely over-shot; randomly turn the excess back to visible.
        hidden = torch.where(~mask[b])[0]
        n_hidden = len(hidden)

        if n_hidden > target:
            # Shuffle; the first `target` stay masked, the rest become visible.
            shuffled = hidden[torch.randperm(n_hidden)]
            mask[b, shuffled[target:]] = True

    return mask
192
+
193
+
194
class MAE(nn.Module):
    """Masked Autoencoder for multi-channel time-series (EEG-style) signals.

    Pipeline: patchify the signal per channel, embed each patch, hide a fixed
    fraction of tokens (spatio-temporal block masking), encode only the
    visible tokens, then reconstruct:
      (a) main path — a transformer decoder predicts the raw samples of every
          patch from visible tokens + mask tokens;
      (b) auxiliary path — a single attention-pooled "global token" built from
          all encoder layers predicts the masked patches on its own.
    `forward` returns (total_loss, main_predictions, mask).
    """

    def __init__(
        self,
        # Data Params
        fs: int = 200,
        patch_seconds: float = 1.0,
        overlap_seconds: float = 0.1,
        # Model Params
        embed_dim: int = 512,
        encoder_depth: int = 12,
        encoder_heads: int = 8,
        decoder_depth: int = 4,
        decoder_heads: int = 8,
        # Training Params
        mask_ratio: float = 0.55,
        aux_loss_weight: float = 0.1,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.mask_ratio = mask_ratio
        # Weight of the auxiliary (global-token) reconstruction loss.
        self.aux_loss_weight = aux_loss_weight

        # 1. Input Processing
        self.patch_embed = PatchEmbed(fs, patch_seconds, overlap_seconds, embed_dim)

        # We calculate patch_size and step from the component we just initialized
        self.patch_size = self.patch_embed.patch_size
        self.step = self.patch_embed.step

        # 2. Positional Encoding (Shared between Encoder and Decoder)
        self.pos_enc = PosEnc(n_freqs=4, embed_dim=embed_dim)

        # 3. Encoder
        self.encoder = TransformerEncoderDecoder(embed_dim=embed_dim, depth=encoder_depth, heads=encoder_heads)

        # 4. Decoder (Main Reconstruction)
        self.decoder = MAEDecoder(embed_dim=embed_dim, decoder_depth=decoder_depth, decoder_heads=decoder_heads, patch_size=self.patch_size)

        # 5. Auxiliary Head (Global Token)
        # We concatenate outputs from ALL encoder layers
        self.aux_dim = encoder_depth * embed_dim

        # A learned query vector to look at the encoder outputs
        self.aux_query = nn.Parameter(torch.randn(1, 1, self.aux_dim))
        nn.init.normal_(self.aux_query, std=0.02)

        # Projection: (Depth * Dim) -> Dim
        self.aux_linear = nn.Linear(self.aux_dim, embed_dim, bias=False)

        # Reconstruction Head for Aux Task
        self.aux_predict = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, self.patch_size))

    def prepare_coords(self, xyz: torch.Tensor, num_patches: int) -> torch.Tensor:
        """Build per-token 4-D coordinates (x, y, z, patch_index).

        xyz: (B, C, 3) spatial position per channel.
        Returns (B, C * num_patches, 4), flattened channel-major — the same
        order produced by flattening the (B, C, P, ...) patch grid below.
        """
        B, C, _ = xyz.shape
        device = xyz.device

        # 2. Generate Time Indices (0, 1, 2, ... P-1)
        time_idx = torch.arange(num_patches, device=device, dtype=torch.float32)

        # 3. Expand Spatial Coords
        # (B, C, 3) -> (B, C, 1, 3) -> (B, C, P, 3)
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)

        # 4. Expand Time Coords
        # (P,) -> (1, 1, P, 1) -> (B, C, P, 1)
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)

        # 5. Concatenate -> (B, C, P, 4)
        coords = torch.cat([spat, time], dim=-1)

        # 6. Flatten to (B, N_Total, 4)
        return coords.flatten(1, 2)

    def forward(self, x: torch.Tensor, xyz: torch.Tensor):
        """One masked-reconstruction training step.

        x:   signal batch; time on the last axis, channels on the second —
             assumed (B, C, T). TODO confirm against the data loader.
        xyz: (B, C, 3) per-channel spatial coordinates.
        Returns (total_loss, predictions_main, mask) where predictions_main is
        (B, N_total, patch_size) and mask is (B, N_total) bool (True = visible).
        """
        B, _, _ = x.shape

        # --- 1. Patchify & Embed ---
        # patches: (B, C, P, PatchSize)
        # NOTE(review): this re-implements PatchEmbed.forward inline (unfold then
        # linear) so that the raw patches are also available as the loss target.
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]

        # tokens: (B, C, P, Dim)
        tokens = self.patch_embed.linear(patches)

        # Flatten to Sequence: (B, N_Total, Dim)
        tokens_flat = tokens.flatten(1, 2)
        patches_flat = patches.flatten(1, 2)  # Target for loss

        # --- 2. Prepare 4D Coordinates ---
        coords = self.prepare_coords(xyz, num_patches)

        # --- 3. Generate Mask ---
        # Returns mask where counts are GUARANTEED to be equal across batch
        mask = generate_mask(coords, mask_ratio=self.mask_ratio)

        # --- 4. Prepare Encoder Input ---
        # We need to extract only the visible tokens and stack them.
        # Since counts are fixed, we can do this efficiently using boolean masking and reshaping.

        # tokens_flat: (B, N_Total, D)
        # mask: (B, N_Total)
        # Result: (B, N_Vis, D)
        # The .view() works because the number of Trues in mask is identical for every row b.
        n_vis = mask[0].sum().item()

        x_vis = tokens_flat[mask].view(B, n_vis, -1)
        coords_vis = coords[mask].view(B, n_vis, -1)

        # Add PE
        pe_vis = self.pos_enc(coords_vis)
        x_vis = x_vis + pe_vis

        # --- 5. Encoder Forward ---
        # intermediates: per-layer FFN outputs, reused by the auxiliary path.
        x_encoded, intermediates = self.encoder(x_vis)

        # --- 6. Main Decoder Path ---
        predictions_main = self.decoder(x_visible=x_encoded, pos_enc=self.pos_enc, coords=coords, mask=mask)

        # --- 7. Auxiliary Path (Global Token) ---
        # Concatenate all intermediate layers: (B, N_Vis, Depth*Dim)
        aux_input = torch.cat(intermediates, dim=-1)

        # Attention Pooling
        # Score = Input @ Query.T
        # (B, N_Vis, AuxDim) @ (1, 1, AuxDim).T -> (B, N_Vis, 1)
        attn_scores = torch.matmul(aux_input, self.aux_query.transpose(1, 2))
        # Softmax over the token axis, so weights sum to 1 per sample.
        attn_weights = F.softmax(attn_scores, dim=1)

        # Pool: Sum(Weights * Input) -> (B, 1, AuxDim)
        global_token = torch.sum(attn_weights * aux_input, dim=1, keepdim=True)

        # Project to Embed Dim: (B, 1, Dim)
        global_emb = self.aux_linear(global_token)

        # Predict Masked Patches
        # 1. Get coords of masked tokens
        # Since mask is fixed count, we can reshape cleanly
        n_masked = (~mask[0]).sum().item()
        coords_masked = coords[~mask].view(B, n_masked, -1)

        pe_masked = self.pos_enc(coords_masked)

        # 2. Expand global token
        global_expanded = global_emb.expand(-1, n_masked, -1)

        # 3. Combine & Predict
        # Every masked position shares the same global embedding; only the
        # positional encoding differentiates the predictions.
        aux_pred_in = global_expanded + pe_masked
        predictions_aux = self.aux_predict(aux_pred_in)

        # --- 8. Loss Calculation ---
        # Target: Only the masked patches
        target_masked = patches_flat[~mask].view(B, n_masked, -1)

        # Main Loss (L1 on masked)
        pred_main_masked = predictions_main[~mask].view(B, n_masked, -1)
        loss_main = F.l1_loss(pred_main_masked, target_masked)

        # Aux Loss (L1 on masked)
        loss_aux = F.l1_loss(predictions_aux, target_masked)

        total_loss = loss_main + self.aux_loss_weight * loss_aux

        return total_loss, predictions_main, mask
manas1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb221351143c45e71ed478a5622a1ccf8f140b983a613f6f5875c862ae48ba76
3
+ size 653413200
modelclass.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+ from tqdm import tqdm
6
+
7
+ from mae import MAE
8
+
9
# Reference 3-D electrode coordinates in cm for a standard 10-20 EEG montage,
# taken from mne's get_montage() positions. Included for reference only:
# nothing in this file reads this dict — callers construct their own `pos`
# tensor for MANAS1.forward.
POSITIONS = {
    "Fp1": (-3.09, 11.46, 2.79),
    "Fp2": (2.84, 11.53, 2.77),
    "F3": (-5.18, 8.67, 7.87),
    "F4": (5.03, 8.74, 7.73),
    "F7": (-7.19, 7.31, 2.58),
    "F8": (7.14, 7.45, 2.51),
    "T3": (-8.60, 1.49, 3.12),
    "T4": (8.33, 1.53, 3.10),
    "C3": (-6.71, 2.34, 10.45),
    "C4": (6.53, 2.36, 10.37),
    "T5": (-8.77, 1.29, -0.77),
    "T6": (8.37, 1.17, -0.77),
    "P3": (-5.50, -4.42, 9.99),
    "P4": (5.36, -4.43, 10.05),
    "O1": (-3.16, -8.06, 5.48),
    "O2": (2.77, -8.05, 5.47),
    "Fz": (-0.12, 9.33, 10.26),
    "Cz": (-0.14, 2.76, 14.02),
    "Pz": (-0.17, -4.52, 12.67),
    "A2": (8.39, 0.20, -2.69),
}
32
+
33
+
34
class MANAS1(nn.Module):
    """Feature extractor built from a pretrained MAE checkpoint.

    Loads the full MAE, then reuses only its patch embedding, positional
    encoding, and encoder. `forward` returns the encoder latents; the
    classification head is left commented out (num_classes/flat_dim are kept
    for when it is enabled).
    """

    def __init__(self, checkpoint_path, num_classes=2, flat_dim=512):
        super().__init__()

        print(f"Loading checkpoint from {checkpoint_path}...")
        # NOTE(review): torch.load without weights_only=True unpickles arbitrary
        # objects — only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location="cpu")

        # Hyperparameters must match the pretraining run for load_state_dict to succeed.
        self.mae = MAE(fs=200, embed_dim=512, encoder_depth=12, encoder_heads=8, decoder_depth=4, decoder_heads=8, mask_ratio=0.55)
        self.mae.load_state_dict(ckpt["model_state_dict"])

        # Expose the pretrained components we actually use at inference time.
        self.patch_embed = self.mae.patch_embed
        self.pos_enc = self.mae.pos_enc
        self.encoder = self.mae.encoder
        self.patch_size = self.mae.patch_size
        self.step = self.mae.step

        self.flat_dim = flat_dim

        # # The Head
        # self.final_layer = nn.Sequential(
        #     nn.Flatten(),
        #     nn.RMSNorm(self.flat_dim),  # Tutorial uses RMSNorm
        #     nn.Dropout(0.1),
        #     nn.Linear(self.flat_dim, num_classes),
        # )

    def prepare_coords(self, xyz, num_patches):
        """Build (B, C*P, 4) coordinates (x, y, z, patch_index) — mirrors MAE.prepare_coords."""
        B, C, _ = xyz.shape
        device = xyz.device
        time_idx = torch.arange(num_patches, device=device).float()
        # (B, C, 3) -> (B, C, P, 3); (P,) -> (B, C, P, 1); concat -> (B, C, P, 4).
        spat = xyz.unsqueeze(2).expand(-1, -1, num_patches, -1)
        time = time_idx.view(1, 1, num_patches, 1).expand(B, C, -1, -1)
        return torch.cat([spat, time], dim=-1).flatten(1, 2)

    def forward(self, x, pos):
        """Encode a signal batch without any masking.

        x:   signal batch with time on the last axis — assumed (B, C, T); TODO confirm.
        pos: (B, C, 3) per-channel spatial coordinates.
        Returns encoder latents of shape (B, C * num_patches, embed_dim).
        """
        # Patchify exactly as in pretraining (same patch_size / step).
        patches = x.unfold(-1, self.patch_size, self.step)
        num_patches = patches.shape[2]

        # Embed and flatten to a token sequence (B, N_total, D).
        tokens = self.patch_embed.linear(patches).flatten(1, 2)

        coords = self.prepare_coords(pos, num_patches)
        pe = self.pos_enc(coords)

        x_enc = tokens + pe
        latents, _ = self.encoder(x_enc)

        # add final layer for classification
        return latents