|
|
"""
|
|
|
Calcium-Bridged Temporal EEG Decoder (V2 - Extended Time Window)
|
|
|
Integrates Phase-Calcium-Latent constraint satisfaction dynamics with EEG temporal windows.
|
|
|
|
|
|
Core Concept: Each EEG time window is processed by a constraint solver whose
|
|
|
calcium/W state carries over to initialize the next window, modeling how the brain
|
|
|
sequentially satisfies perceptual constraints.
|
|
|
|
|
|
V2 Update: The time window has been extended to 550ms based on ERP analysis from the
|
|
|
Alljoined1 paper, adding a 'CognitiveEvaluation' stage to capture late-stage
|
|
|
semantic and working memory signals.
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import json
|
|
|
import tkinter as tk
|
|
|
from tkinter import ttk, filedialog, messagebox
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
import torch.nn.functional as F
|
|
|
import numpy as np
|
|
|
import threading
|
|
|
import queue
|
|
|
from pathlib import Path
|
|
|
from collections import defaultdict
|
|
|
import matplotlib.pyplot as plt
|
|
|
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
|
|
|
|
|
|
try:
    from datasets import load_dataset
except ImportError as e:
    # SystemExit with a message prints it to stderr and exits with status 1,
    # so shells/CI see the failure (bare ``exit()`` exits with status 0 and is
    # a ``site``-module convenience that is not always available).
    raise SystemExit(f"Missing dependency: {e}") from e

# Let cuDNN benchmark and cache the fastest convolution algorithms; the
# per-window input sizes are fixed, so the tuning is reused across batches.
torch.backends.cudnn.benchmark = True
|
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# EEG sampling rate in Hz; used to convert the millisecond bounds in
# TIME_WINDOWS into sample indices.
EEG_SAMPLE_RATE = 512

# Mini-batch size shared by the training and validation DataLoaders.
BATCH_SIZE = 64

# Processing stages as (start_ms, end_ms, stage_name), in the order the model
# chains them (each stage's calcium/W state is handed to the next).
# NOTE(review): the ms bounds are converted to indices from the first sample of
# each EEG array, i.e. they assume sample 0 is stimulus onset — confirm against
# the dataset's epoching.
TIME_WINDOWS = [
    (50, 150, "EarlyVisual"),            # 100 ms
    (150, 250, "MidFeature"),            # 100 ms
    (250, 350, "LateSemantic"),          # 100 ms
    (350, 550, "CognitiveEvaluation"),   # 200 ms late-stage window (V2)
]
|
|
|
|
|
|
# COCO category name -> COCO category id for the 23 decoded classes.
# Insertion order is significant: a category's position in this dict is its
# column in the multi-hot label vector built by the dataset.
TARGET_CATEGORIES = {
    # Animals
    'elephant': 22,
    'giraffe': 25,
    'bear': 23,
    'zebra': 24,
    'cow': 21,
    'sheep': 20,
    'horse': 19,
    'dog': 18,
    'cat': 17,
    'bird': 16,
    # Vehicles
    'airplane': 5,
    'train': 7,
    'boat': 9,
    'bus': 6,
    'truck': 8,
    'motorcycle': 4,
    'bicycle': 2,
    'car': 3,
    # Outdoor objects
    'traffic light': 10,
    'fire hydrant': 11,
    'stop sign': 13,
    'parking meter': 14,
    'bench': 15,
}

# Reverse lookup (COCO category id -> name); its keys double as the set of
# category ids that are kept when reading annotations.
CATEGORY_NAMES = dict(zip(TARGET_CATEGORIES.values(), TARGET_CATEGORIES.keys()))
|
|
|
|
|
|
class CalciumAttentionModule(nn.Module):
    """
    Phase-Calcium-Latent dynamics for one time window.

    Models constraint satisfaction via neuromorphic oscillator dynamics:
    the input is projected to a phase vector ``phi``; a leaky calcium trace
    integrates the mean pairwise phase coherence ``|cos(phi_i - phi_j)|``; a
    sigmoid gate computed from the calcium trace modulates ``phi @ W``; and the
    result is added back to ``phi`` and LayerNorm-ed.

    Args:
        n_features: Size of the last dimension of the input ``x``.
        d_model: Width of the phase, calcium and coupling spaces. The calcium
            gate has ``ceil(d_model / 2)`` outputs that are tiled twice and
            trimmed to ``d_model``, so odd widths work too (for even widths the
            gate is exactly ``d_model // 2`` wide).
    """

    def __init__(self, n_features, d_model=256):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model

        # Input -> phase space.
        self.phase_proj = nn.Linear(n_features, d_model)

        # Calcium -> gate values in (0, 1). Half width, rounded up, so two tiled
        # copies always cover d_model.
        self.ca_gate = nn.Sequential(
            nn.Linear(d_model, (d_model + 1) // 2),
            nn.Sigmoid(),
        )

        # This window's coupling matrix, initialised close to zero.
        self.W = nn.Parameter(torch.randn(d_model, d_model) * 0.01)

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, prev_ca=None, prev_W=None):
        """
        Run one window of calcium-gated coupling.

        Args:
            x: Input features ``[batch, n_features]``.
            prev_ca: Previous window's calcium state ``[batch, d_model]``, or
                ``None`` to start from zeros.
            prev_W: Previous window's coupling matrix ``[d_model, d_model]``,
                or ``None`` to use only this module's ``W``.

        Returns:
            ``(features, calcium_state, W_matrix)``: ``features`` and
            ``calcium_state`` are ``[batch, d_model]``; ``W_matrix`` is the
            ``[d_model, d_model]`` coupling actually used, to be passed on as
            the next window's ``prev_W``.
        """
        phi = self.phase_proj(x)

        if prev_ca is None:
            ca = torch.zeros(x.size(0), self.d_model, device=x.device, dtype=phi.dtype)
        else:
            # No copy needed: the leaky update below is out-of-place, so the
            # caller's tensor is never modified.
            ca = prev_ca

        # Carried-over coupling plus this window's own learned refinement.
        # Replacing prev_W with self.W (instead of adding) would leave the W of
        # every module after the first unused and without gradients.
        W = self.W if prev_W is None else prev_W + self.W

        # Pairwise phase coherence |cos(phi_i - phi_j)|: [batch, d_model, d_model].
        coherence = torch.abs(torch.cos(phi.unsqueeze(2) - phi.unsqueeze(1)))
        # Leaky integration: keep 95% of the previous calcium, add 5% of the
        # mean coherence of each unit with all others.
        ca = ca * 0.95 + coherence.mean(dim=2) * 0.05

        gate = self.ca_gate(ca)
        # Tile the half-width gate to full width (trim is a no-op for even d_model).
        gate_full = torch.cat([gate, gate], dim=1)[:, :self.d_model]

        coupled = torch.matmul(phi, W)
        features = self.norm(coupled * gate_full + phi)

        return features, ca, W
|
|
|
|
|
|
|
|
|
class TemporalConstraintEEGModel(nn.Module):
    """
    Sequential constraint satisfaction across EEG time windows.

    Each window is a constraint solver whose state primes the next: window i
    has its own CNN encoder and CalciumAttentionModule, and the calcium state
    and coupling matrix returned for window i are passed into window i + 1.
    The per-window features are concatenated and mapped to ``num_classes``
    logits. (Dynamically sized based on the TIME_WINDOWS constant.)

    Args:
        n_channels: EEG channels per window (input channels of each encoder).
        num_classes: Number of output logits.
    """

    def __init__(self, n_channels=64, num_classes=len(TARGET_CATEGORIES)):
        super().__init__()
        self.n_channels = n_channels
        n_windows = len(TIME_WINDOWS)

        # One encoder per window: [batch, n_channels, time] -> [batch, 256, 1].
        self.window_encoders = nn.ModuleList(
            [self._build_cnn_encoder() for _ in range(n_windows)]
        )

        # One calcium module per window, operating on the 256-d encoder output.
        self.ca_modules = nn.ModuleList(
            [CalciumAttentionModule(256, d_model=256) for _ in range(n_windows)]
        )

        # Classifier over the concatenation of every window's 256-d features.
        self.classifier = nn.Sequential(
            nn.Linear(256 * n_windows, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def _build_cnn_encoder(self):
        """Simple CNN for one time window: two conv stages, then global average pooling."""
        return nn.Sequential(
            nn.Conv1d(self.n_channels, 128, kernel_size=15, padding=7),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 256, kernel_size=7, padding=3),
            nn.BatchNorm1d(256),
            nn.ELU(),
            nn.AdaptiveAvgPool1d(1),
        )

    def forward(self, eeg_windows, *, return_ca_history=True):
        """
        Classify one batch of windowed EEG.

        Args:
            eeg_windows: Sequence with one tensor ``[batch, channels, timepoints]``
                per TIME_WINDOWS entry, in the same order.
            return_ca_history: When True (default), also collect each window's
                calcium state as a detached NumPy array. Pass False to skip the
                per-window device-to-host copy (e.g. during training).

        Returns:
            ``(logits, ca_history)`` where ``logits`` is ``[batch, num_classes]``
            and ``ca_history`` is a list of per-window calcium states (empty when
            ``return_ca_history`` is False).

        Raises:
            ValueError: If the number of windows does not match the model.
        """
        expected = len(self.window_encoders)
        if len(eeg_windows) != expected:
            # zip() would silently drop extra encoders and the classifier would
            # then fail with an opaque shape error.
            raise ValueError(
                f"expected {expected} EEG windows (one per TIME_WINDOWS entry), "
                f"got {len(eeg_windows)}"
            )

        window_features = []
        ca_history = []
        ca_state = None
        W_state = None

        for encoder, ca_module, eeg_window in zip(
            self.window_encoders, self.ca_modules, eeg_windows
        ):
            # [batch, 256, 1] -> [batch, 256]
            cnn_features = encoder(eeg_window).squeeze(-1)

            # Calcium/W state from the previous window primes this one.
            features, ca_state, W_state = ca_module(
                cnn_features,
                prev_ca=ca_state,
                prev_W=W_state,
            )

            window_features.append(features)
            if return_ca_history:
                ca_history.append(ca_state.detach().cpu().numpy())

        logits = self.classifier(torch.cat(window_features, dim=1))

        return logits, ca_history
|
|
|
|
|
|
|
|
|
class CalciumEEGDataset(Dataset):
    """
    Dataset that provides EEG data split by time windows.

    Rows come from the ``Alljoined/05_125`` Hugging Face dataset. Each row is
    labelled with a multi-hot vector over TARGET_CATEGORIES derived from COCO
    annotations; rows whose ``coco_id`` has no annotation in a target category
    are dropped.

    Args:
        coco_path: COCO root directory (stored as ``self.coco_path``; not read
            by this class).
        annotations_path: COCO annotations JSON with an ``annotations`` list of
            objects carrying ``image_id`` and ``category_id``.
        split: Dataset split to load.
        max_samples: If truthy, keep only the first ``max_samples`` rows.
        trials_to_average: Accepted for API compatibility; currently unused.
    """

    def __init__(self, coco_path, annotations_path, split='train',
                 max_samples=None, trials_to_average=1):
        self.coco_path = Path(coco_path)

        print(f"Loading Alljoined ({split})...")
        self.dataset = load_dataset("Alljoined/05_125", split=split, streaming=False)
        if max_samples:
            self.dataset = self.dataset.select(range(min(int(max_samples), len(self.dataset))))

        print("Loading COCO annotations...")
        with open(annotations_path, 'r', encoding='utf-8') as f:
            coco_data = json.load(f)

        # image_id -> set of target category ids present in that image.
        self.image_categories = defaultdict(set)
        for ann in coco_data['annotations']:
            if ann['category_id'] in CATEGORY_NAMES:
                self.image_categories[ann['image_id']].add(ann['category_id'])

        # Column of each category in the label vector (TARGET_CATEGORIES order),
        # built once instead of re-scanning the values list for every label.
        category_index = {
            cat_id: col for col, cat_id in enumerate(TARGET_CATEGORIES.values())
        }

        print("Pre-caching EEG data...")
        self.samples = []
        # Read only the coco_id column so every row's EEG array is not decoded
        # just to build the label index.
        for idx, coco_id in enumerate(self.dataset['coco_id']):
            cat_ids = self.image_categories.get(coco_id)
            if not cat_ids:
                continue
            label = torch.zeros(len(TARGET_CATEGORIES))
            for cat_id in cat_ids:
                label[category_index[cat_id]] = 1.0
            self.samples.append((idx, label))

        print(f"Cached {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return ``(eeg_windows, label)`` for cached sample ``idx``.

        ``eeg_windows`` holds one float32 tensor per TIME_WINDOWS entry with
        shape ``[channels, end_idx - start_idx]``, z-scored per channel within
        the window. Windows extending past the end of the recording are padded
        by repeating the last sample, or zero-filled if no samples remain.
        """
        sample_idx, label = self.samples[idx]
        sample = self.dataset[sample_idx]

        # Axis 1 is treated as time below. NOTE(review): assumes sample 0 is
        # stimulus onset — confirm against the dataset's epoching.
        eeg_data = np.array(sample['EEG'], dtype=np.float32)
        n_channels = eeg_data.shape[0]

        eeg_windows = []
        for start_ms, end_ms, _ in TIME_WINDOWS:
            start_idx = int((start_ms / 1000.0) * EEG_SAMPLE_RATE)
            end_idx = int((end_ms / 1000.0) * EEG_SAMPLE_RATE)
            target_len = end_idx - start_idx

            # Slicing clamps at the end of the recording, so a short recording
            # yields a shorter (possibly empty) window.
            window = eeg_data[:, start_idx:end_idx]
            missing = target_len - window.shape[1]
            if missing > 0:
                if window.shape[1] == 0:
                    # np.pad(mode='edge') cannot extend an empty axis; there is
                    # no signal to repeat, so fill with zeros.
                    window = np.zeros((n_channels, target_len), dtype=np.float32)
                else:
                    window = np.pad(window, ((0, 0), (0, missing)), mode='edge')

            # Per-channel z-score; the epsilon keeps flat channels finite.
            window = (window - window.mean(axis=1, keepdims=True)) / \
                (window.std(axis=1, keepdims=True) + 1e-8)

            eeg_windows.append(torch.from_numpy(window).float())

        return eeg_windows, label
|
|
|
|
|
|
|
|
|
class CalciumEEGTrainerGUI(tk.Tk):
    """
    Tkinter front-end for training TemporalConstraintEEGModel.

    Training runs on a daemon worker thread. Tk is not thread-safe, so the
    worker never touches widgets or Tk variables directly. Instead it pushes
    log strings and zero-argument callables onto ``self.log_queue``, and
    ``process_logs`` drains that queue on the Tk main thread every 100 ms.
    """

    def __init__(self):
        super().__init__()
        self.title("Calcium-Bridged Temporal EEG Decoder V2")
        self.geometry("1200x850")

        self.coco_path = ""
        self.annotations_path = ""
        self.train_thread = None
        # Set by the Stop button; polled by the worker between batches/epochs.
        self.stop_flag = threading.Event()
        # Worker -> main-thread channel: str items are log lines, callables are UI actions.
        self.log_queue = queue.Queue()

        self.setup_gui()
        self.process_logs()

    def setup_gui(self):
        """Create and lay out all widgets."""
        title = tk.Label(self, text="Calcium-Bridged Temporal EEG Decoder (V2 - Extended Window)",
                         font=("Arial", 14, "bold"))
        title.pack(pady=10)

        info = tk.Label(self,
                        text="Sequential constraint satisfaction across 4 ERP time windows up to 550ms\n"
                             "Calcium/W state from early windows primes later windows",
                        fg="blue", font=("Arial", 9))
        info.pack(pady=5)

        # --- Dataset paths ---
        path_frame = ttk.LabelFrame(self, text="Dataset")
        path_frame.pack(pady=5, padx=10, fill=tk.X)

        tk.Label(path_frame, text="COCO:").grid(row=0, column=0, sticky=tk.W, padx=5, pady=3)
        self.coco_var = tk.StringVar()
        ttk.Entry(path_frame, textvariable=self.coco_var, width=50).grid(row=0, column=1, padx=5)
        ttk.Button(path_frame, text="Browse", command=self.browse_coco).grid(row=0, column=2)

        tk.Label(path_frame, text="Annotations:").grid(row=1, column=0, sticky=tk.W, padx=5, pady=3)
        self.ann_var = tk.StringVar()
        ttk.Entry(path_frame, textvariable=self.ann_var, width=50).grid(row=1, column=1, padx=5)
        ttk.Button(path_frame, text="Browse", command=self.browse_ann).grid(row=1, column=2)

        # --- Training settings ---
        settings_frame = ttk.LabelFrame(self, text="Training Settings")
        settings_frame.pack(pady=5, padx=10, fill=tk.X)

        tk.Label(settings_frame, text="Max Samples:").grid(row=0, column=0, padx=5)
        self.max_var = tk.IntVar(value=3000)
        tk.Spinbox(settings_frame, from_=1000, to=10000, increment=1000,
                   textvariable=self.max_var, width=10).grid(row=0, column=1)

        tk.Label(settings_frame, text="Epochs:").grid(row=0, column=2, padx=5)
        self.epochs_var = tk.IntVar(value=100)
        tk.Spinbox(settings_frame, from_=50, to=500, increment=50,
                   textvariable=self.epochs_var, width=10).grid(row=0, column=3)

        # --- One description row per TIME_WINDOWS entry ---
        windows_frame = ttk.LabelFrame(self, text="Constraint Satisfaction Stages")
        windows_frame.pack(pady=5, padx=10, fill=tk.X)

        stage_descriptions = {
            "EarlyVisual": "Low-level visual features (edges, textures)",
            "MidFeature": "Mid-level binding (parts, shapes)",
            "LateSemantic": "High-level semantics (concepts, context)",
            "CognitiveEvaluation": "Memory, context check, final decision",
        }
        for start, end, label in TIME_WINDOWS:
            # .get() so a newly added window without a description still renders.
            tk.Label(windows_frame,
                     text=f"{label} ({start}-{end}ms): {stage_descriptions.get(label, '')}",
                     font=("Courier", 9)).pack(anchor=tk.W, padx=10, pady=2)

        # --- Train / Stop buttons ---
        btn_frame = tk.Frame(self)
        btn_frame.pack(pady=10)

        self.train_btn = tk.Button(btn_frame, text="Train Extended Model (V2)",
                                   command=self.start_train,
                                   bg="#4CAF50", fg="white", font=("Arial", 10, "bold"))
        self.train_btn.pack(side=tk.LEFT, padx=5)

        self.stop_btn = tk.Button(btn_frame, text="Stop",
                                  command=self.stop_train,
                                  bg="#f44336", fg="white",
                                  state=tk.DISABLED)
        self.stop_btn.pack(side=tk.LEFT, padx=5)

        # --- Progress and log ---
        self.progress = ttk.Progressbar(self, mode='determinate')
        self.progress.pack(fill=tk.X, padx=10, pady=5)

        log_frame = ttk.LabelFrame(self, text="Training Log")
        log_frame.pack(pady=5, padx=10, fill=tk.BOTH, expand=True)

        self.log_text = tk.Text(log_frame, height=20, bg='black', fg='lightgreen',
                                font=('Courier', 8))
        self.log_text.pack(fill=tk.BOTH, expand=True)

    def browse_coco(self):
        """Pick the COCO root directory."""
        path = filedialog.askdirectory()
        if path:
            self.coco_var.set(path)
            self.coco_path = path

    def browse_ann(self):
        """Pick the COCO annotations JSON file."""
        path = filedialog.askopenfilename(filetypes=[("JSON", "*.json")])
        if path:
            self.ann_var.set(path)
            self.annotations_path = path

    def log(self, msg):
        """Queue ``msg`` for the log pane (safe to call from any thread)."""
        self.log_queue.put(msg)

    def _ui(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` to run on the Tk main thread (safe from any thread)."""
        self.log_queue.put(lambda: func(*args, **kwargs))

    def process_logs(self):
        """Drain the queue on the main thread: show log lines, run UI actions, then reschedule."""
        try:
            while True:
                item = self.log_queue.get_nowait()
                if callable(item):
                    item()
                else:
                    self.log_text.insert(tk.END, f"{item}\n")
                    self.log_text.see(tk.END)
        except queue.Empty:
            pass
        finally:
            # Always reschedule, even if a queued UI action raised, so polling
            # never stops.
            self.after(100, self.process_logs)

    def start_train(self):
        """Validate the inputs on the main thread and start training on a worker thread."""
        if self.train_thread is not None and self.train_thread.is_alive():
            return

        # Honour paths typed directly into the entries, not only Browse selections.
        self.coco_path = self.coco_var.get().strip()
        self.annotations_path = self.ann_var.get().strip()
        if not self.coco_path or not self.annotations_path:
            messagebox.showerror("Error", "Select paths first")
            return

        # Read Tk variables here (main thread); the worker only receives plain ints.
        try:
            max_samples = int(self.max_var.get())
            epochs = int(self.epochs_var.get())
        except (tk.TclError, ValueError):
            messagebox.showerror("Error", "Max Samples and Epochs must be integers")
            return
        if epochs < 1:
            messagebox.showerror("Error", "Epochs must be at least 1")
            return

        self.stop_flag.clear()
        self.train_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.progress['value'] = 0

        self.train_thread = threading.Thread(target=self._train_model,
                                             args=(max_samples, epochs), daemon=True)
        self.train_thread.start()

    def stop_train(self):
        """Ask the worker to stop at the next batch/epoch boundary."""
        self.stop_flag.set()

    def _train_one_epoch(self, model, loader, optimizer, criterion):
        """Run one optimisation pass; return the mean batch loss, or None if stopped."""
        model.train()
        total_loss = 0.0
        n_batches = 0
        for eeg_windows, labels in loader:
            if self.stop_flag.is_set():
                return None

            eeg_windows = [w.to(DEVICE) for w in eeg_windows]
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            logits, _ = model(eeg_windows)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1
        return total_loss / max(n_batches, 1)

    def _evaluate(self, model, loader, criterion):
        """Return the mean batch loss over ``loader`` in eval mode without gradients."""
        model.eval()
        total_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for eeg_windows, labels in loader:
                eeg_windows = [w.to(DEVICE) for w in eeg_windows]
                labels = labels.to(DEVICE)
                logits, _ = model(eeg_windows)
                total_loss += criterion(logits, labels).item()
                n_batches += 1
        return total_loss / max(n_batches, 1)

    def _train_model(self, max_samples=None, epochs=None):
        """
        Worker-thread entry point: load data, train, and checkpoint the best model.

        Args:
            max_samples: Row limit passed to CalciumEEGDataset; read from the
                Max Samples spinbox when None.
            epochs: Number of epochs; read from the Epochs spinbox when None.

        The checkpoint with the lowest validation loss is written to
        ``calcium_bridge_eeg_model_v2.pth``. Any exception is logged. The
        Train/Stop buttons are always restored through the UI queue.
        """
        try:
            if max_samples is None:
                max_samples = self.max_var.get()
            if epochs is None:
                epochs = self.epochs_var.get()

            self.log("="*70)
            self.log("CALCIUM-BRIDGED TEMPORAL EEG DECODER (V2 - Extended Window)")
            self.log("="*70)
            self.log("\nConcept: Sequential constraint satisfaction across FOUR time windows")
            self.log("Now capturing late-stage cognitive evaluation signals up to 550ms\n")

            dataset = CalciumEEGDataset(
                self.coco_path,
                self.annotations_path,
                'train',
                max_samples
            )

            total = len(dataset)
            train_size = int(0.8 * total)
            val_size = total - train_size
            if train_size < 2 or val_size < 1:
                # Avoid division by zero on empty loaders and single-sample BatchNorm batches.
                raise ValueError(f"Need at least 3 labelled samples to train and validate, got {total}")

            train_set, val_set = torch.utils.data.random_split(
                dataset, [train_size, val_size],
                generator=torch.Generator().manual_seed(42)
            )

            self.log(f"Train: {train_size}, Val: {val_size}")

            # Pinned memory only helps host-to-GPU copies.
            pin_memory = DEVICE.type == "cuda"
            # The classifier's BatchNorm1d cannot train on a batch of one sample,
            # so drop a trailing singleton batch if the split would produce one.
            train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                                      shuffle=True, num_workers=0, pin_memory=pin_memory,
                                      drop_last=(train_size % BATCH_SIZE == 1))
            val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
                                    shuffle=False, num_workers=0, pin_memory=pin_memory)

            model = TemporalConstraintEEGModel().to(DEVICE)
            self.log(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

            optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
            criterion = nn.BCEWithLogitsLoss()
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

            best_val_loss = float('inf')

            for epoch in range(epochs):
                if self.stop_flag.is_set():
                    break

                train_loss = self._train_one_epoch(model, train_loader, optimizer, criterion)
                if train_loss is None:
                    # Stopped mid-epoch: don't validate or checkpoint a partial epoch.
                    break

                val_loss = self._evaluate(model, val_loader, criterion)
                scheduler.step()

                self._ui(self.progress.configure, value=((epoch + 1) / epochs) * 100)

                report = epoch % 5 == 0 or epoch + 1 == epochs
                if report:
                    self.log(f"Epoch {epoch+1}/{epochs}: "
                             f"TrLoss={train_loss:.4f} ValLoss={val_loss:.4f}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'val_loss': val_loss,
                        'epoch': epoch
                    }, "calcium_bridge_eeg_model_v2.pth")
                    if report:
                        self.log(f" -> Saved V2 model (val_loss={val_loss:.4f})")

            self.log("\n" + "="*70)
            self.log("TRAINING STOPPED" if self.stop_flag.is_set() else "TRAINING COMPLETE")
            self.log(f"Best Val Loss: {best_val_loss:.4f}")
            self.log("="*70)

        except Exception as e:
            self.log(f"ERROR: {e}")
            import traceback
            self.log(traceback.format_exc())
        finally:
            # Widget updates must happen on the Tk main thread.
            self._ui(self.train_btn.config, state=tk.NORMAL)
            self._ui(self.stop_btn.config, state=tk.DISABLED)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the Tk root window and hand control to its event loop; training is
    # launched from the GUI on a background thread.
    app = CalciumEEGTrainerGUI()
    app.mainloop()