Aluode committed on
Commit
258e525
·
verified ·
1 Parent(s): 1fc310c

Upload 3 files

Browse files
.gitattributes CHANGED
@@ -1 +1 @@
1
- *.pth filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ pic.png filter=lfs diff=lfs merge=lfs -text
pic.png ADDED

Git LFS Details

  • SHA256: 5df29272c318160da88c03d7300967874d666078f2c31cd420acee0da789966a
  • Pointer size: 131 Bytes
  • Size of remote file: 526 kB
pkas_cal_trainer_gemini.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Calcium-Bridged Temporal EEG Decoder (V2 - Extended Time Window)
3
+ Integrates Phase-Calcium-Latent constraint satisfaction dynamics with EEG temporal windows.
4
+
5
+ Core Concept: Each EEG time window is processed by a constraint solver whose
6
+ calcium/W state carries over to initialize the next window, modeling how the brain
7
+ sequentially satisfies perceptual constraints.
8
+
9
+ V2 Update: The time window has been extended to 550ms based on ERP analysis from the
10
+ Alljoined1 paper, adding a 'CognitiveEvaluation' stage to capture late-stage
11
+ semantic and working memory signals.
12
+ """
13
+
14
+ import os
15
+ import json
16
+ import tkinter as tk
17
+ from tkinter import ttk, filedialog, messagebox
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.optim as optim
21
+ from torch.utils.data import Dataset, DataLoader
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+ import threading
25
+ import queue
26
+ from pathlib import Path
27
+ from collections import defaultdict
28
+ import matplotlib.pyplot as plt
29
+ from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
30
+
31
+ try:
32
+ from datasets import load_dataset
33
+ torch.backends.cudnn.benchmark = True
34
+ except ImportError as e:
35
+ print(f"Missing dependency: {e}")
36
+ exit()
37
+
38
# Run on GPU when available; all tensors/models are moved here explicitly.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Samples per second of the Alljoined/05_125 EEG recordings.
# NOTE(review): assumed to match the dataset card -- TODO confirm.
EEG_SAMPLE_RATE = 512
BATCH_SIZE = 64

# --- CRITICAL CHANGE V2 ---
# Extended temporal windows to capture later cognitive processing (P300/N400/P600)
# This aligns the model's "attention span" with the neuroscience data.
# Each entry is (start_ms, end_ms, stage_name); milliseconds are converted to
# sample indices via EEG_SAMPLE_RATE in the dataset classes.
TIME_WINDOWS = [
    (50, 150, "EarlyVisual"),           # Low-level visual constraints (P100)
    (150, 250, "MidFeature"),           # Mid-level binding (N170/P200)
    (250, 350, "LateSemantic"),         # High-level semantics (P300/N400 start)
    (350, 550, "CognitiveEvaluation")   # Deeper context, memory, final check (P300/P600)
]

# COCO category name -> COCO category id for the 23 target classes
# (animals, vehicles, and street objects). The dict's insertion order defines
# the label index used by CalciumEEGDataset (via list(...).index()).
TARGET_CATEGORIES = {
    'elephant': 22, 'giraffe': 25, 'bear': 23, 'zebra': 24,
    'cow': 21, 'sheep': 20, 'horse': 19, 'dog': 18, 'cat': 17, 'bird': 16,
    'airplane': 5, 'train': 7, 'boat': 9, 'bus': 6, 'truck': 8,
    'motorcycle': 4, 'bicycle': 2, 'car': 3,
    'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13,
    'parking meter': 14, 'bench': 15,
}

# Reverse lookup: COCO category id -> human-readable name.
CATEGORY_NAMES = {v: k for k, v in TARGET_CATEGORIES.items()}
62
+
63
class CalciumAttentionModule(nn.Module):
    """
    Phase-Calcium-Latent dynamics for a single EEG time window.

    Models constraint satisfaction with neuromorphic oscillator dynamics:
    inputs are projected to "phases", pairwise phase coherence is integrated
    into a slow calcium trace, and the calcium trace gates a learned coupling
    of the phase vector.
    """

    def __init__(self, n_features, d_model=256):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model
        # Projects raw window features into the phase space (Kuramoto-like).
        self.phase_proj = nn.Linear(n_features, d_model)
        # Maps the calcium trace to a [0, 1] attention gate (half width; it
        # is tiled back up to d_model in forward, so d_model must be even).
        self.ca_gate = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.Sigmoid()
        )
        # Learned constraint-coupling matrix; small init keeps early coupling weak.
        self.W = nn.Parameter(torch.randn(d_model, d_model) * 0.01)
        # LayerNorm stabilises the gated + residual output.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, prev_ca=None, prev_W=None):
        """
        Args:
            x: input features, shape [batch, n_features].
            prev_ca: calcium state carried over from the previous window,
                shape [batch, d_model], or None to start from zero.
            prev_W: coupling matrix carried over from the previous window,
                or None to use this module's own learned W.

        Returns:
            (features, calcium, coupling): gated features [batch, d_model],
            the updated calcium trace, and the coupling matrix used.
        """
        n = x.size(0)

        # Project into phase space.
        phases = self.phase_proj(x)                       # [batch, d_model]

        # Start fresh or continue the previous window's calcium trace.
        calcium = (torch.zeros(n, self.d_model, device=x.device)
                   if prev_ca is None else prev_ca.clone())

        # First window uses our own W; later windows inherit the carried one.
        coupling = self.W if prev_W is None else prev_W

        # Pairwise phase coherence: |cos(phi_i - phi_j)| is high when phases
        # align, low when they conflict. Mean over j gives one score per unit.
        alignment = torch.abs(torch.cos(phases.unsqueeze(2) - phases.unsqueeze(1)))
        # Leaky temporal integration of coherence into the calcium trace.
        calcium = calcium * 0.95 + alignment.mean(dim=2) * 0.05

        # Calcium-derived attention gate in [0, 1].
        gate = self.ca_gate(calcium)                      # [batch, d_model//2]

        # Apply the learned constraint coupling (mutual constraint satisfaction).
        coupled = phases @ coupling                       # [batch, d_model]

        # Tile the half-width gate to full width and gate the coupled signal.
        gated = coupled * torch.cat([gate, gate], dim=1)

        # Residual connection plus normalisation for stability.
        out = self.norm(gated + phases)

        return out, calcium, coupling
131
+
132
+
133
class TemporalConstraintEEGModel(nn.Module):
    """
    Sequential constraint satisfaction across EEG time windows.

    Each window gets its own CNN encoder and calcium-attention module; the
    calcium/W state produced by one window primes the next, so early windows
    shape how later windows are interpreted. Sizing follows TIME_WINDOWS, so
    adding/removing windows automatically resizes the classifier input.
    """

    def __init__(self, n_channels=64, num_classes=len(TARGET_CATEGORIES)):
        super().__init__()
        self.n_channels = n_channels

        # One CNN feature extractor per temporal window.
        self.window_encoders = nn.ModuleList(
            self._build_cnn_encoder() for _ in TIME_WINDOWS
        )

        # One constraint-satisfaction (calcium) module per temporal window.
        self.ca_modules = nn.ModuleList(
            CalciumAttentionModule(256, d_model=256) for _ in TIME_WINDOWS
        )

        # --- CRITICAL CHANGE V2 ---
        # Classifier input grows with the window count (256 features each),
        # so with 4 windows it is 256 * 4 wide.
        self.classifier = nn.Sequential(
            nn.Linear(256 * len(TIME_WINDOWS), 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def _build_cnn_encoder(self):
        """Two-stage 1D CNN that pools one EEG window down to a 256-d vector."""
        layers = [
            nn.Conv1d(self.n_channels, 128, kernel_size=15, padding=7),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 256, kernel_size=7, padding=3),
            nn.BatchNorm1d(256),
            nn.ELU(),
            nn.AdaptiveAvgPool1d(1),
        ]
        return nn.Sequential(*layers)

    def forward(self, eeg_windows):
        """
        Args:
            eeg_windows: list of [batch, channels, timepoints] tensors,
                one per entry in TIME_WINDOWS.

        Returns:
            (logits, ca_history): classification logits and one detached numpy
            snapshot of the calcium state per window (for visualization).
        """
        per_window = []
        ca_snapshots = []
        ca_state = None
        W_state = None

        # Process windows in order so calcium/W carry over between stages.
        for encoder, ca_module, window in zip(
            self.window_encoders, self.ca_modules, eeg_windows
        ):
            pooled = encoder(window).squeeze(-1)          # [batch, 256]
            feats, ca_state, W_state = ca_module(
                pooled, prev_ca=ca_state, prev_W=W_state
            )
            per_window.append(feats)
            # Snapshot for analysis; detached so it never enters autograd.
            ca_snapshots.append(ca_state.detach().cpu().numpy())

        # Classify on the concatenation of all window features.
        logits = self.classifier(torch.cat(per_window, dim=1))

        return logits, ca_snapshots
213
+
214
+
215
class CalciumEEGDataset(Dataset):
    """Dataset that provides EEG data split by time windows.

    Loads the Alljoined/05_125 EEG dataset plus COCO annotations, keeps only
    samples whose image contains at least one target category, and serves each
    sample as a list of per-window EEG tensors plus a multi-hot label.
    """
    def __init__(self, coco_path, annotations_path, split='train',
                 max_samples=None, trials_to_average=1):
        # NOTE(review): trials_to_average is accepted but never used below --
        # confirm whether trial averaging was intended here.
        self.coco_path = Path(coco_path)

        # Load dataset (downloads/streams from the Hugging Face hub on first use).
        print(f"Loading Alljoined ({split})...")
        self.dataset = load_dataset("Alljoined/05_125", split=split, streaming=False)

        if max_samples:
            self.dataset = self.dataset.select(range(min(int(max_samples), len(self.dataset))))

        # Load COCO annotations
        print(f"Loading COCO annotations...")
        with open(annotations_path, 'r') as f:
            coco_data = json.load(f)

        # image_id -> set of TARGET category ids present on that image.
        self.image_categories = defaultdict(set)
        for ann in coco_data['annotations']:
            img_id = ann['image_id']
            if ann['category_id'] in CATEGORY_NAMES:
                self.image_categories[img_id].add(ann['category_id'])

        # Pre-cache samples: keep (dataset index, multi-hot label) pairs only
        # for samples whose image has at least one target category.
        print("Pre-caching EEG data...")
        self.samples = []
        for idx, sample in enumerate(self.dataset):
            coco_id = sample['coco_id']
            if coco_id in self.image_categories and len(self.image_categories[coco_id]) > 0:
                # Multi-hot label over the target categories; index order is
                # the insertion order of TARGET_CATEGORIES.
                label = torch.zeros(len(TARGET_CATEGORIES))
                for cat_id in self.image_categories[coco_id]:
                    if cat_id in CATEGORY_NAMES:
                        cat_idx = list(TARGET_CATEGORIES.values()).index(cat_id)
                        label[cat_idx] = 1.0

                if label.sum() > 0:
                    self.samples.append((idx, label))

        print(f"Cached {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (eeg_windows, label): one normalized float tensor per
        TIME_WINDOWS entry plus the cached multi-hot label."""
        sample_idx, label = self.samples[idx]
        sample = self.dataset[sample_idx]

        # NOTE(review): assumes sample['EEG'] is [channels, timepoints] --
        # confirm against the dataset schema.
        eeg_data = np.array(sample['EEG'], dtype=np.float32)

        # Extract time windows
        eeg_windows = []
        for start_ms, end_ms, _ in TIME_WINDOWS:
            # Convert millisecond bounds to sample indices.
            start_idx = int((start_ms / 1000.0) * EEG_SAMPLE_RATE)
            end_idx = int((end_ms / 1000.0) * EEG_SAMPLE_RATE)

            if eeg_data.shape[1] >= end_idx:
                window = eeg_data[:, start_idx:end_idx]
            else:
                # Recording is shorter than the window end: take what exists.
                window = eeg_data[:, start_idx:]
                # Pad if needed (edge-replication keeps the last sample value).
                if window.shape[1] < (end_idx - start_idx):
                    pad_width = (end_idx - start_idx) - window.shape[1]
                    window = np.pad(window, ((0,0), (0, pad_width)), mode='edge')

            # Normalize per channel (z-score); epsilon guards flat channels.
            window = (window - window.mean(axis=1, keepdims=True)) / \
                     (window.std(axis=1, keepdims=True) + 1e-8)

            eeg_windows.append(torch.from_numpy(window).float())

        return eeg_windows, label
287
+
288
+
289
class CalciumEEGTrainerGUI(tk.Tk):
    """Tkinter front-end that configures and runs V2 training on a worker thread.

    Log lines flow from the worker thread to the Text widget through a
    thread-safe queue drained by process_logs() on the Tk event loop.
    NOTE(review): the worker thread also touches Tk widgets directly
    (progress bar, button states in _train_model) -- Tk is generally not
    thread-safe, so those updates should ideally go through the queue too;
    confirm before relying on this under load.
    """
    def __init__(self):
        super().__init__()
        self.title("Calcium-Bridged Temporal EEG Decoder V2")
        self.geometry("1200x850")

        self.coco_path = ""           # COCO images root (set via browse_coco)
        self.annotations_path = ""    # COCO annotations JSON (set via browse_ann)
        self.train_thread = None      # background training thread, if running
        self.stop_flag = threading.Event()  # cooperative cancellation signal
        self.log_queue = queue.Queue()      # worker -> GUI log messages

        self.setup_gui()
        self.process_logs()  # starts the periodic log-drain loop

    def setup_gui(self):
        """Build all widgets: header, path pickers, settings, stage info, log."""
        # Title
        title = tk.Label(self, text="Calcium-Bridged Temporal EEG Decoder (V2 - Extended Window)",
                         font=("Arial", 14, "bold"))
        title.pack(pady=10)

        info = tk.Label(self,
                        text="Sequential constraint satisfaction across 4 ERP time windows up to 550ms\n"
                             "Calcium/W state from early windows primes later windows",
                        fg="blue", font=("Arial", 9))
        info.pack(pady=5)

        # Paths
        path_frame = ttk.LabelFrame(self, text="Dataset")
        path_frame.pack(pady=5, padx=10, fill=tk.X)

        tk.Label(path_frame, text="COCO:").grid(row=0, column=0, sticky=tk.W, padx=5, pady=3)
        self.coco_var = tk.StringVar()
        ttk.Entry(path_frame, textvariable=self.coco_var, width=50).grid(row=0, column=1, padx=5)
        ttk.Button(path_frame, text="Browse", command=self.browse_coco).grid(row=0, column=2)

        tk.Label(path_frame, text="Annotations:").grid(row=1, column=0, sticky=tk.W, padx=5, pady=3)
        self.ann_var = tk.StringVar()
        ttk.Entry(path_frame, textvariable=self.ann_var, width=50).grid(row=1, column=1, padx=5)
        ttk.Button(path_frame, text="Browse", command=self.browse_ann).grid(row=1, column=2)

        # Settings
        settings_frame = ttk.LabelFrame(self, text="Training Settings")
        settings_frame.pack(pady=5, padx=10, fill=tk.X)

        tk.Label(settings_frame, text="Max Samples:").grid(row=0, column=0, padx=5)
        self.max_var = tk.IntVar(value=3000)
        tk.Spinbox(settings_frame, from_=1000, to=10000, increment=1000,
                   textvariable=self.max_var, width=10).grid(row=0, column=1)

        tk.Label(settings_frame, text="Epochs:").grid(row=0, column=2, padx=5)
        self.epochs_var = tk.IntVar(value=100)
        tk.Spinbox(settings_frame, from_=50, to=500, increment=50,
                   textvariable=self.epochs_var, width=10).grid(row=0, column=3)

        # --- CRITICAL CHANGE V2 ---
        # Updated GUI to reflect the new 4-stage process.
        windows_frame = ttk.LabelFrame(self, text="Constraint Satisfaction Stages")
        windows_frame.pack(pady=5, padx=10, fill=tk.X)

        for start, end, label in TIME_WINDOWS:
            # Stage-name -> one-line description shown next to each window.
            desc = {
                "EarlyVisual": "Low-level visual features (edges, textures)",
                "MidFeature": "Mid-level binding (parts, shapes)",
                "LateSemantic": "High-level semantics (concepts, context)",
                "CognitiveEvaluation": "Memory, context check, final decision"
            }
            tk.Label(windows_frame,
                     text=f"{label} ({start}-{end}ms): {desc[label]}",
                     font=("Courier", 9)).pack(anchor=tk.W, padx=10, pady=2)

        # Buttons
        btn_frame = tk.Frame(self)
        btn_frame.pack(pady=10)

        self.train_btn = tk.Button(btn_frame, text="Train Extended Model (V2)",
                                   command=self.start_train,
                                   bg="#4CAF50", fg="white", font=("Arial", 10, "bold"))
        self.train_btn.pack(side=tk.LEFT, padx=5)

        self.stop_btn = tk.Button(btn_frame, text="Stop",
                                  command=self.stop_train,
                                  bg="#f44336", fg="white",
                                  state=tk.DISABLED)
        self.stop_btn.pack(side=tk.LEFT, padx=5)

        # Progress
        self.progress = ttk.Progressbar(self, mode='determinate')
        self.progress.pack(fill=tk.X, padx=10, pady=5)

        # Log
        log_frame = ttk.LabelFrame(self, text="Training Log")
        log_frame.pack(pady=5, padx=10, fill=tk.BOTH, expand=True)

        self.log_text = tk.Text(log_frame, height=20, bg='black', fg='lightgreen',
                                font=('Courier', 8))
        self.log_text.pack(fill=tk.BOTH, expand=True)

    def browse_coco(self):
        """Pick the COCO images root directory."""
        path = filedialog.askdirectory()
        if path:
            self.coco_var.set(path)
            self.coco_path = path

    def browse_ann(self):
        """Pick the COCO annotations JSON file."""
        path = filedialog.askopenfilename(filetypes=[("JSON", "*.json")])
        if path:
            self.ann_var.set(path)
            self.annotations_path = path

    def log(self, msg):
        """Thread-safe logging: enqueue a line for the GUI thread to display."""
        self.log_queue.put(msg)

    def process_logs(self):
        """Drain queued log lines into the Text widget; reschedules itself every 100 ms."""
        try:
            while not self.log_queue.empty():
                msg = self.log_queue.get_nowait()
                self.log_text.insert(tk.END, msg + "\n")
                self.log_text.see(tk.END)
        except queue.Empty:
            # Benign race between empty() and get_nowait(); try again next tick.
            pass
        self.after(100, self.process_logs)

    def start_train(self):
        """Validate paths, flip button states, and launch the training thread."""
        if not self.coco_path or not self.annotations_path:
            messagebox.showerror("Error", "Select paths first")
            return

        self.stop_flag.clear()
        self.train_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)

        # Daemon thread so it never blocks application exit.
        self.train_thread = threading.Thread(target=self._train_model, daemon=True)
        self.train_thread.start()

    def stop_train(self):
        """Request cooperative cancellation; checked between batches/epochs."""
        self.stop_flag.set()

    def _train_model(self):
        """Worker-thread entry point: full train/validate loop with checkpointing.

        Trains TemporalConstraintEEGModel with multi-label BCE loss, cosine
        warm-restart LR scheduling, gradient clipping, and best-val-loss
        checkpointing to calcium_bridge_eeg_model_v2.pth.
        """
        try:
            self.log("="*70)
            self.log("CALCIUM-BRIDGED TEMPORAL EEG DECODER (V2 - Extended Window)")
            self.log("="*70)
            self.log("\nConcept: Sequential constraint satisfaction across FOUR time windows")
            self.log("Now capturing late-stage cognitive evaluation signals up to 550ms\n")

            # Create dataset
            dataset = CalciumEEGDataset(
                self.coco_path,
                self.annotations_path,
                'train',
                self.max_var.get()
            )

            # 80/20 train/validation split, seeded for reproducibility.
            total = len(dataset)
            train_size = int(0.8 * total)
            val_size = total - train_size

            train_set, val_set = torch.utils.data.random_split(
                dataset, [train_size, val_size],
                generator=torch.Generator().manual_seed(42)
            )

            self.log(f"Train: {train_size}, Val: {val_size}")

            train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                                      shuffle=True, num_workers=0, pin_memory=True)
            val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
                                    shuffle=False, num_workers=0, pin_memory=True)

            # Create model (will be automatically sized for 4 windows)
            model = TemporalConstraintEEGModel().to(DEVICE)
            self.log(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

            optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
            # BCEWithLogits: multi-label targets (an image can hold several categories).
            criterion = nn.BCEWithLogitsLoss()
            # Warm restarts every 20 epochs, doubling each cycle; stepped per epoch.
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

            best_val_loss = float('inf')

            for epoch in range(self.epochs_var.get()):
                if self.stop_flag.is_set():
                    break

                # Train
                model.train()
                train_loss = 0
                for eeg_windows, labels in train_loader:
                    if self.stop_flag.is_set():
                        break

                    eeg_windows = [w.to(DEVICE) for w in eeg_windows]
                    labels = labels.to(DEVICE)

                    optimizer.zero_grad()
                    logits, _ = model(eeg_windows)
                    loss = criterion(logits, labels)
                    loss.backward()
                    # Clip to tame exploding gradients from the recurrent calcium carryover.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

                    train_loss += loss.item()

                # Validate
                model.eval()
                val_loss = 0
                with torch.no_grad():
                    for eeg_windows, labels in val_loader:
                        eeg_windows = [w.to(DEVICE) for w in eeg_windows]
                        labels = labels.to(DEVICE)
                        logits, _ = model(eeg_windows)
                        loss = criterion(logits, labels)
                        val_loss += loss.item()

                train_loss /= len(train_loader)
                val_loss /= len(val_loader)

                scheduler.step()

                # NOTE(review): widget update from the worker thread -- see class docstring.
                self.progress['value'] = ((epoch + 1) / self.epochs_var.get()) * 100

                if epoch % 5 == 0:
                    self.log(f"Epoch {epoch+1}/{self.epochs_var.get()}: "
                             f"TrLoss={train_loss:.4f} ValLoss={val_loss:.4f}")

                # Checkpoint whenever validation improves.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'val_loss': val_loss,
                        'epoch': epoch
                    }, "calcium_bridge_eeg_model_v2.pth")  # Save as V2
                    if epoch % 5 == 0:
                        self.log(f"  -> Saved V2 model (val_loss={val_loss:.4f})")

            self.log("\n" + "="*70)
            self.log("TRAINING COMPLETE")
            self.log(f"Best Val Loss: {best_val_loss:.4f}")
            self.log("="*70)

        except Exception as e:
            self.log(f"ERROR: {e}")
            import traceback
            self.log(traceback.format_exc())
        finally:
            # NOTE(review): widget updates from the worker thread -- see class docstring.
            self.train_btn.config(state=tk.NORMAL)
            self.stop_btn.config(state=tk.DISABLED)
536
+
537
+
538
if __name__ == "__main__":
    # Launch the Tk training GUI; mainloop() blocks until the window closes.
    app = CalciumEEGTrainerGUI()
    app.mainloop()
pkas_cal_viewer_gemini2.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Calcium-Bridge EEG Constraint Viewer (V2.1 - Fixed)
3
+ Visualizes how constraint satisfaction unfolds across four temporal windows up to 550ms.
4
+
5
+ Shows:
6
+ 1. Original COCO image
7
+ 2. EEG heatmaps for each of the 4 time windows
8
+ 3. Calcium "attention" evolution (what the model focuses on at each stage)
9
+ 4. Top predictions crystallizing across the 4 windows
10
+
11
+ V2.1 Fixes:
12
+ - Corrected 'figsize' argument placement during figure creation.
13
+ - Corrected colorbar creation to use the figure object directly, resolving warnings.
14
+ """
15
+
16
+ import os
17
+ import tkinter as tk
18
+ from tkinter import filedialog, messagebox, ttk
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from PIL import Image, ImageTk
24
+ import matplotlib.pyplot as plt
25
+ from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
26
+ import json
27
+ from pathlib import Path
28
+ from collections import defaultdict
29
+ import random
30
+
31
+ try:
32
+ from datasets import load_dataset
33
+ except ImportError:
34
+ print("Missing datasets library.")
35
+ exit()
36
+
37
# Run on GPU when available; model and inputs are moved here explicitly.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Samples per second of the Alljoined/05_125 EEG recordings.
# NOTE(review): assumed to match the dataset card -- TODO confirm.
EEG_SAMPLE_RATE = 512

# (start_ms, end_ms, stage_name) -- must match the V2 trainer's TIME_WINDOWS
# exactly, since the loaded checkpoint's classifier is sized from this list.
TIME_WINDOWS = [
    (50, 150, "EarlyVisual"),
    (150, 250, "MidFeature"),
    (250, 350, "LateSemantic"),
    (350, 550, "CognitiveEvaluation")
]

# COCO category name -> COCO category id for the 23 target classes; insertion
# order defines label indices and must match the trainer.
TARGET_CATEGORIES = {
    'elephant': 22, 'giraffe': 25, 'bear': 23, 'zebra': 24,
    'cow': 21, 'sheep': 20, 'horse': 19, 'dog': 18, 'cat': 17, 'bird': 16,
    'airplane': 5, 'train': 7, 'boat': 9, 'bus': 6, 'truck': 8,
    'motorcycle': 4, 'bicycle': 2, 'car': 3,
    'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13,
    'parking meter': 14, 'bench': 15,
}

# Reverse lookup: COCO category id -> name.
CATEGORY_NAMES = {v: k for k, v in TARGET_CATEGORIES.items()}
# Id sets used to filter test images: keep images that contain at least one
# target category and none of the excluded (non-target) COCO categories.
TARGET_IDS = set(TARGET_CATEGORIES.values())
ALL_COCO_IDS = list(range(1, 91))  # COCO category ids span 1..90
EXCLUDED_IDS = set(ALL_COCO_IDS) - TARGET_IDS
60
+
61
+
62
+ # === MODEL ARCHITECTURE (Must match V2 training code) ===
63
+
64
class CalciumAttentionModule(nn.Module):
    """Phase-Calcium-Latent dynamics for one time window.

    Mirror of the V2 training architecture: parameter names and layer order
    must stay identical so checkpoints load by state-dict key.
    """

    def __init__(self, n_features, d_model=256):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model
        # Projects window features into phase space.
        self.phase_proj = nn.Linear(n_features, d_model)
        # Calcium -> [0, 1] gate at half width (tiled to d_model in forward).
        self.ca_gate = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.Sigmoid()
        )
        # Learned constraint-coupling matrix.
        self.W = nn.Parameter(torch.randn(d_model, d_model) * 0.01)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, prev_ca=None, prev_W=None):
        """Run one constraint-satisfaction step.

        Returns (features, calcium, coupling); calcium/coupling feed the next
        window's call as prev_ca/prev_W.
        """
        n = x.size(0)
        phases = self.phase_proj(x)

        # Continue the carried calcium trace, or start from zero.
        calcium = (torch.zeros(n, self.d_model, device=x.device)
                   if prev_ca is None else prev_ca.clone())

        # First window uses this module's W; later windows inherit the carried one.
        coupling = self.W if prev_W is None else prev_W

        # Pairwise phase coherence, leakily integrated into the calcium trace.
        alignment = torch.abs(torch.cos(phases.unsqueeze(2) - phases.unsqueeze(1)))
        calcium = calcium * 0.95 + alignment.mean(dim=2) * 0.05

        # Calcium-gated coupling plus residual, then normalisation.
        gate = self.ca_gate(calcium)
        coupled = phases @ coupling
        gated = coupled * torch.cat([gate, gate], dim=1)
        out = self.norm(gated + phases)

        return out, calcium, coupling
99
+
100
+
101
class TemporalConstraintEEGModel(nn.Module):
    """
    Sequential constraint satisfaction across EEG time windows
    (must match the V2 training architecture so checkpoints load by key).

    Fix: the per-window "prefix" classifier pass used only for visualization
    is now computed only when return_intermediates=True. Previously it ran on
    every window of every forward call, wasting len(TIME_WINDOWS) extra
    classifier passes per call and — if the model were in train mode —
    polluting the classifier's BatchNorm running statistics with zero-padded
    inputs. Returned values are unchanged in both call modes (with
    return_intermediates=False, ca_history is still an empty list).
    """
    def __init__(self, n_channels=64, num_classes=len(TARGET_CATEGORIES)):
        super().__init__()
        self.n_channels = n_channels

        # One CNN encoder per temporal window.
        self.window_encoders = nn.ModuleList([
            self._build_cnn_encoder() for _ in TIME_WINDOWS
        ])

        # One calcium-attention module per temporal window.
        self.ca_modules = nn.ModuleList([
            CalciumAttentionModule(256, d_model=256) for _ in TIME_WINDOWS
        ])

        # Classifier input scales with the window count (256 features each).
        self.classifier = nn.Sequential(
            nn.Linear(256 * len(TIME_WINDOWS), 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def _build_cnn_encoder(self):
        """Two-stage 1D CNN that pools one EEG window down to a 256-d vector."""
        return nn.Sequential(
            nn.Conv1d(self.n_channels, 128, kernel_size=15, padding=7),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 256, kernel_size=7, padding=3),
            nn.BatchNorm1d(256),
            nn.ELU(),
            nn.AdaptiveAvgPool1d(1)
        )

    def forward(self, eeg_windows, return_intermediates=False):
        """
        Args:
            eeg_windows: list of [batch, channels, timepoints] tensors,
                one per entry in TIME_WINDOWS.
            return_intermediates: when True, also return per-window calcium
                states, coupling matrices, and prefix-classifier logits for
                visualization.

        Returns:
            (logits, ca_history) by default, or
            (logits, ca_history, W_history, window_logits_list) when
            return_intermediates=True.
        """
        window_features, ca_history, W_history, window_logits_list = [], [], [], []
        ca_state, W_state = None, None

        for encoder, ca_module, eeg_window in zip(
            self.window_encoders, self.ca_modules, eeg_windows
        ):
            cnn_features = encoder(eeg_window).squeeze(-1)  # [batch, 256]
            features, ca_state, W_state = ca_module(cnn_features, ca_state, W_state)

            window_features.append(features)

            if return_intermediates:
                ca_history.append(ca_state.detach())
                W_history.append(W_state.detach())

                # "Prefix" prediction: pad the not-yet-seen windows with zeros
                # so the fixed-size classifier can score the partial evidence.
                padded_features = window_features + [
                    torch.zeros_like(features)
                    for _ in range(len(TIME_WINDOWS) - len(window_features))
                ]
                intermediate_logits = self.classifier(torch.cat(padded_features, dim=1))
                window_logits_list.append(intermediate_logits.detach())

        combined = torch.cat(window_features, dim=1)
        logits = self.classifier(combined)

        if return_intermediates:
            return logits, ca_history, W_history, window_logits_list
        return logits, ca_history
163
+
164
+
165
+ # === DATA LOADER ===
166
# === DATA LOADER ===
class FilteredTestDataset:
    """Loads the Alljoined test split and keeps only "clean" samples: images
    annotated with at least one target category and no category outside the
    target set. Serves per-window, z-scored EEG arrays for the viewer."""
    def __init__(self, annotations_path, max_samples=1000):
        print("Loading and filtering test dataset...")
        # Network/disk I/O: pulls the dataset from the Hugging Face hub on first use.
        self.eeg_dataset = load_dataset("Alljoined/05_125", split='test', streaming=False).select(range(max_samples))
        with open(annotations_path, 'r') as f:
            coco_data = json.load(f)

        # image_id -> set of ALL COCO category ids on that image (not just targets).
        image_annotations = defaultdict(set)
        for ann in coco_data['annotations']:
            image_annotations[ann['image_id']].add(ann['category_id'])

        self.filtered_samples = []
        # NOTE: idx is unused; enumerate kept as written.
        for idx, sample in enumerate(self.eeg_dataset):
            ann_ids = image_annotations.get(sample['coco_id'], set())
            # Keep only images with no excluded category and >= 1 target category.
            if not any(cat_id in EXCLUDED_IDS for cat_id in ann_ids) and any(cat_id in TARGET_IDS for cat_id in ann_ids):
                self.filtered_samples.append({
                    'coco_id': sample['coco_id'],
                    # NOTE(review): assumes sample['EEG'] is [channels, timepoints] -- confirm.
                    'eeg_data': np.array(sample['EEG'], dtype=np.float32)
                })
        print(f"Loaded {len(self.filtered_samples)} filtered test samples.")
        if not self.filtered_samples: raise RuntimeError("No suitable test samples found.")

    def get_eeg_windows(self, sample_info):
        """Slice one cached EEG array into the TIME_WINDOWS segments, pad/trim
        each to its nominal length, and z-score per channel.

        Returns a list of numpy arrays, one per window."""
        eeg_data = sample_info['eeg_data']
        eeg_windows = []
        for start_ms, end_ms, _ in TIME_WINDOWS:
            # Convert millisecond bounds to sample indices.
            start_idx, end_idx = int(start_ms / 1000 * EEG_SAMPLE_RATE), int(end_ms / 1000 * EEG_SAMPLE_RATE)
            n_timepoints = end_idx - start_idx
            window = eeg_data[:, start_idx:end_idx] if eeg_data.shape[1] >= end_idx else eeg_data[:, start_idx:]

            if window.shape[1] != n_timepoints:
                pad_width = n_timepoints - window.shape[1]
                # Edge-pad short windows; trim over-long ones.
                window = np.pad(window, ((0,0), (0, pad_width)), 'edge') if pad_width > 0 else window[:, :n_timepoints]

            # Per-channel z-scoring; epsilon guards flat channels.
            window = (window - window.mean(axis=1, keepdims=True)) / (window.std(axis=1, keepdims=True) + 1e-8)
            eeg_windows.append(window)
        return eeg_windows

    def get_random_sample_info(self):
        """Return one random cached sample dict ({'coco_id', 'eeg_data'})."""
        return random.choice(self.filtered_samples)
206
+
207
+ # === VIEWER APPLICATION ===
208
+ class CalciumBridgeViewer(tk.Tk):
209
+ def __init__(self):
210
+ super().__init__()
211
+ self.title("Calcium-Bridge EEG Constraint Viewer V2 (Extended Window)")
212
+ self.geometry("2000x1000")
213
+ self.model, self.test_data = None, None
214
+ self.setup_gui()
215
+
216
+ def setup_gui(self):
217
+ control_frame = ttk.Frame(self); control_frame.pack(pady=10, padx=10, fill=tk.X)
218
+ ttk.Label(control_frame, text="COCO Path:").pack(side=tk.LEFT, padx=5)
219
+ self.coco_var = tk.StringVar(); ttk.Entry(control_frame, textvariable=self.coco_var, width=20).pack(side=tk.LEFT, padx=2)
220
+ ttk.Button(control_frame, text="Browse", command=self.browse_coco).pack(side=tk.LEFT, padx=5)
221
+ ttk.Label(control_frame, text="Annotations:").pack(side=tk.LEFT, padx=5)
222
+ self.ann_var = tk.StringVar(); ttk.Entry(control_frame, textvariable=self.ann_var, width=20).pack(side=tk.LEFT, padx=2)
223
+ ttk.Button(control_frame, text="Browse", command=self.browse_ann).pack(side=tk.LEFT, padx=5)
224
+ ttk.Button(control_frame, text="Load V2 Model", command=self.load_model).pack(side=tk.LEFT, padx=20)
225
+ self.test_btn = ttk.Button(control_frame, text="Test Random Sample", command=self.test_sample, state=tk.DISABLED); self.test_btn.pack(side=tk.LEFT, padx=5)
226
+ self.status_label = tk.Label(control_frame, text="Model: Not loaded", fg="gray"); self.status_label.pack(side=tk.LEFT, padx=20)
227
+
228
+ main_paned = ttk.PanedWindow(self, orient=tk.HORIZONTAL); main_paned.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
229
+ image_frame = ttk.Frame(main_paned, width=400); main_paned.add(image_frame, weight=0)
230
+ ttk.Label(image_frame, text="COCO Image", font=("Arial", 12, "bold")).pack(pady=5)
231
+ self.image_canvas = tk.Canvas(image_frame, width=400, height=400, bg='lightgray'); self.image_canvas.pack()
232
+ self.coco_id_label = ttk.Label(image_frame, text="COCO ID: N/A"); self.coco_id_label.pack(pady=5)
233
+
234
+ self.notebook = ttk.Notebook(main_paned); main_paned.add(self.notebook, weight=1)
235
+ self.create_tabs()
236
+
237
+ def create_tabs(self):
238
+ self.constraint_fig, self.constraint_canvas = self.create_tab("Constraint Satisfaction", "How predictions crystallize as constraints are satisfied")
239
+ self.calcium_fig, self.calcium_canvas = self.create_tab("Calcium Attention", "Calcium state evolution: What the model 'focuses on' at each stage")
240
+ self.eeg_fig, self.eeg_canvas = self.create_tab("EEG Heatmaps", "Raw EEG signals for each time window")
241
+
242
+ def create_tab(self, title, description):
243
+ tab = ttk.Frame(self.notebook); self.notebook.add(tab, text=title)
244
+ ttk.Label(tab, text=description, font=("Arial", 11)).pack(pady=5)
245
+ fig = plt.Figure()
246
+ canvas = FigureCanvasTkAgg(fig, tab); canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
247
+ return fig, canvas
248
+
249
+ def browse_coco(self):
250
+ path = filedialog.askdirectory(); self.coco_var.set(path); self.coco_path = path
251
+
252
+ def browse_ann(self):
253
+ path = filedialog.askopenfilename(filetypes=[("JSON", "*.json")]); self.ann_var.set(path); self.annotations_path = path
254
+
255
+ def load_model(self):
256
+ model_path = filedialog.askopenfilename(filetypes=[("PyTorch Model", "*.pth")], title="Select calcium_bridge_eeg_model_v2.pth")
257
+ if not model_path or not self.annotations_path: return
258
+ try:
259
+ checkpoint = torch.load(model_path, map_location=DEVICE)
260
+ self.model = TemporalConstraintEEGModel().to(DEVICE)
261
+ self.model.load_state_dict(checkpoint['model_state_dict'])
262
+ self.model.eval()
263
+ self.test_data = FilteredTestDataset(self.annotations_path)
264
+ self.status_label.config(text="Model: V2 Loaded ✓", fg="green")
265
+ self.test_btn.config(state=tk.NORMAL)
266
+ except Exception as e: messagebox.showerror("Error", f"Failed to load model:\n{e}"); print(traceback.format_exc())
267
+
268
+ def _fetch_image(self, coco_id):
269
+ formatted_id = f"{coco_id:012d}.jpg"
270
+ for s in ["train2017", "val2017", "test2017"]:
271
+ path = os.path.join(self.coco_path, s, formatted_id)
272
+ if os.path.exists(path): return Image.open(path).convert("RGB")
273
+ return None
274
+
275
+ def test_sample(self):
276
+ if not self.model: return
277
+ try:
278
+ sample_info = self.test_data.get_random_sample_info()
279
+ image = self._fetch_image(sample_info['coco_id'])
280
+ if image: self.display_image(image, sample_info['coco_id'])
281
+
282
+ eeg_windows_np = self.test_data.get_eeg_windows(sample_info)
283
+ eeg_windows = [torch.from_numpy(w).unsqueeze(0).to(DEVICE) for w in eeg_windows_np]
284
+
285
+ with torch.no_grad():
286
+ logits, ca_history, _, window_logits = self.model(eeg_windows, return_intermediates=True)
287
+
288
+ self.visualize_constraint_satisfaction(window_logits, logits)
289
+ self.visualize_calcium_evolution(ca_history)
290
+ self.visualize_eeg_heatmaps(eeg_windows_np)
291
+ except Exception as e: messagebox.showerror("Error", f"Failed to process sample:\n{e}"); print(traceback.format_exc())
292
+
293
+ def display_image(self, image, coco_id):
294
+ ratio = min(400/image.width, 400/image.height)
295
+ resized = image.resize((int(image.width * ratio), int(image.height * ratio)), Image.LANCZOS)
296
+ self.pil_image_tk = ImageTk.PhotoImage(resized)
297
+ self.image_canvas.create_image(200, 200, image=self.pil_image_tk)
298
+ self.coco_id_label.config(text=f"COCO ID: {coco_id}")
299
+
300
+ def visualize_constraint_satisfaction(self, window_logits, final_logits):
301
+ self.constraint_fig.clear()
302
+ cat_list = list(TARGET_CATEGORIES.keys())
303
+ n_windows = len(window_logits)
304
+ final_probs = torch.sigmoid(final_logits).squeeze(0).cpu().numpy()
305
+ top_indices = np.argsort(final_probs)[::-1][:10]
306
+ axes = self.constraint_fig.subplots(1, n_windows + 1)
307
+
308
+ for i, (ax, wl) in enumerate(zip(axes[:-1], window_logits)):
309
+ probs = torch.sigmoid(wl).squeeze(0).cpu().numpy()[top_indices]
310
+ ax.barh([cat_list[idx] for idx in top_indices], probs, color='steelblue')
311
+ ax.set_title(f"{TIME_WINDOWS[i][2]}\n({TIME_WINDOWS[i][0]}-{TIME_WINDOWS[i][1]}ms)", fontsize=10)
312
+ ax.set_xlim(0, 1); ax.invert_yaxis(); ax.tick_params(axis='y', labelsize=8)
313
+
314
+ axes[-1].barh([cat_list[idx] for idx in top_indices], final_probs[top_indices], color='darkgreen')
315
+ axes[-1].set_title("Final\n(Combined)", fontsize=10); axes[-1].set_xlim(0, 1); axes[-1].invert_yaxis(); axes[-1].tick_params(axis='y', labelsize=8)
316
+ self.constraint_fig.suptitle("Constraint Satisfaction: Predictions Crystallizing Over Time", fontsize=14, fontweight='bold')
317
+ self.constraint_fig.tight_layout(); self.constraint_canvas.draw()
318
+
319
+ def visualize_calcium_evolution(self, ca_history):
320
+ self.calcium_fig.clear()
321
+ n_windows = len(ca_history)
322
+ axes = self.calcium_fig.subplots(2, n_windows)
323
+
324
+ for i, ca_state in enumerate(ca_history):
325
+ ca_np = ca_state.squeeze(0).cpu().numpy()
326
+ top_20_idx = np.argsort(ca_np)[::-1][:20]
327
+ axes[0, i].plot(ca_np, 'r'); axes[0, i].fill_between(range(len(ca_np)), ca_np, color='r', alpha=0.3)
328
+ axes[0, i].set_title(f"{TIME_WINDOWS[i][2]}\n({TIME_WINDOWS[i][0]}-{TIME_WINDOWS[i][1]}ms)", fontsize=10)
329
+ axes[1, i].barh([f"F{idx}" for idx in top_20_idx], ca_np[top_20_idx], color='darkred')
330
+ axes[1, i].invert_yaxis(); axes[1, i].tick_params(axis='y', labelsize=7)
331
+ self.calcium_fig.suptitle("Calcium Attention: What the Model Focuses On", fontsize=14, fontweight='bold')
332
+ self.calcium_fig.tight_layout(); self.calcium_canvas.draw()
333
+
334
+ def visualize_eeg_heatmaps(self, eeg_windows_np):
335
+ self.eeg_fig.clear()
336
+ n_windows = len(eeg_windows_np)
337
+ axes = self.eeg_fig.subplots(1, n_windows)
338
+
339
+ for i, (ax, eeg_data) in enumerate(zip(axes, eeg_windows_np)):
340
+ im = ax.imshow(eeg_data, aspect='auto', cmap='RdBu_r', vmin=-3, vmax=3)
341
+ ax.set_title(f"{TIME_WINDOWS[i][2]}\n({TIME_WINDOWS[i][0]}-{TIME_WINDOWS[i][1]}ms)", fontsize=10)
342
+ if i == 0: ax.set_ylabel("Channel")
343
+ self.eeg_fig.colorbar(im, ax=ax) # CORRECTED
344
+ self.eeg_fig.suptitle("Raw EEG Signals by Time Window", fontsize=14, fontweight='bold')
345
+ self.eeg_fig.tight_layout(); self.eeg_canvas.draw()
346
+
347
if __name__ == "__main__":
    # Launch the Tk event loop for the viewer application.
    viewer = CalciumBridgeViewer()
    viewer.mainloop()