CalciumBridgeEEGConstraintViewer / pkas_cal_trainer_gemini.py
Aluode's picture
Upload 3 files
258e525 verified
"""
Calcium-Bridged Temporal EEG Decoder (V2 - Extended Time Window)
Integrates Phase-Calcium-Latent constraint satisfaction dynamics with EEG temporal windows.
Core Concept: Each EEG time window is processed by a constraint solver whose
calcium/W state carries over to initialize the next window, modeling how the brain
sequentially satisfies perceptual constraints.
V2 Update: The time window has been extended to 550ms based on ERP analysis from the
Alljoined1 paper, adding a 'CognitiveEvaluation' stage to capture late-stage
semantic and working memory signals.
"""
import os
import json
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
import threading
import queue
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
try:
from datasets import load_dataset
torch.backends.cudnn.benchmark = True
except ImportError as e:
print(f"Missing dependency: {e}")
exit()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EEG_SAMPLE_RATE = 512
BATCH_SIZE = 64
# --- CRITICAL CHANGE V2 ---
# Extended temporal windows to capture later cognitive processing (P300/N400/P600)
# This aligns the model's "attention span" with the neuroscience data.
TIME_WINDOWS = [
(50, 150, "EarlyVisual"), # Low-level visual constraints (P100)
(150, 250, "MidFeature"), # Mid-level binding (N170/P200)
(250, 350, "LateSemantic"), # High-level semantics (P300/N400 start)
(350, 550, "CognitiveEvaluation") # Deeper context, memory, final check (P300/P600)
]
TARGET_CATEGORIES = {
'elephant': 22, 'giraffe': 25, 'bear': 23, 'zebra': 24,
'cow': 21, 'sheep': 20, 'horse': 19, 'dog': 18, 'cat': 17, 'bird': 16,
'airplane': 5, 'train': 7, 'boat': 9, 'bus': 6, 'truck': 8,
'motorcycle': 4, 'bicycle': 2, 'car': 3,
'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13,
'parking meter': 14, 'bench': 15,
}
CATEGORY_NAMES = {v: k for k, v in TARGET_CATEGORIES.items()}
class CalciumAttentionModule(nn.Module):
"""
Phase-Calcium-Latent dynamics for one time window.
Models constraint satisfaction via neuromorphic oscillator dynamics.
"""
def __init__(self, n_features, d_model=256):
super().__init__()
self.n_features = n_features
self.d_model = d_model
# Phase dynamics (Kuramoto-like)
self.phase_proj = nn.Linear(n_features, d_model)
# Calcium dynamics (gating/attention)
self.ca_gate = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.Sigmoid()
)
# Latent coupling matrix (W) - learned constraint structure
self.W = nn.Parameter(torch.randn(d_model, d_model) * 0.01)
# Layer norm for stability
self.norm = nn.LayerNorm(d_model)
def forward(self, x, prev_ca=None, prev_W=None):
"""
x: Input features [batch, n_features]
prev_ca: Previous window's calcium state [batch, d_model]
prev_W: Previous window's coupling matrix [d_model, d_model]
Returns: features, calcium_state, W_matrix
"""
batch_size = x.size(0)
# Phase projection
phi = self.phase_proj(x) # [batch, d_model]
# Initialize or carry over calcium
if prev_ca is None:
ca = torch.zeros(batch_size, self.d_model, device=x.device)
else:
ca = prev_ca.clone()
# Initialize or carry over W (coupling structure)
W = self.W if prev_W is None else prev_W
# Calcium accumulation (coherence-based)
# High when features are aligned (low when conflicting)
coherence = torch.abs(torch.cos(phi[:, :, None] - phi[:, None, :]))
ca_update = torch.mean(coherence, dim=2) # [batch, d_model]
ca = ca * 0.95 + ca_update * 0.05 # Temporal integration
# Calcium-gated attention
ca_gate = self.ca_gate(ca) # [batch, d_model//2]
# Apply constraint coupling (W matrix)
# This is where "mutual constraint satisfaction" happens
coupled = torch.matmul(phi, W) # [batch, d_model]
# Gate the coupling by calcium (only attend where calcium is high)
ca_gate_full = torch.cat([ca_gate, ca_gate], dim=1) # Expand to d_model
features = coupled * ca_gate_full
# Normalize
features = self.norm(features + phi) # Residual connection
return features, ca, W
class TemporalConstraintEEGModel(nn.Module):
"""
Sequential constraint satisfaction across EEG time windows.
Each window is a constraint solver whose state primes the next.
(Dynamically sized based on TIME_WINDOWS constant)
"""
def __init__(self, n_channels=64, num_classes=len(TARGET_CATEGORIES)):
super().__init__()
self.n_channels = n_channels
# CNN feature extractors for each time window
self.window_encoders = nn.ModuleList([
self._build_cnn_encoder() for _ in TIME_WINDOWS
])
# Calcium-attention modules for each window
self.ca_modules = nn.ModuleList([
CalciumAttentionModule(256, d_model=256) for _ in TIME_WINDOWS
])
# --- CRITICAL CHANGE V2 ---
# The input layer is now automatically larger (256 * 4) because len(TIME_WINDOWS) is 4.
self.classifier = nn.Sequential(
nn.Linear(256 * len(TIME_WINDOWS), 512),
nn.BatchNorm1d(512),
nn.GELU(),
nn.Dropout(0.3),
nn.Linear(512, num_classes)
)
def _build_cnn_encoder(self):
"""Simple CNN for one time window"""
return nn.Sequential(
nn.Conv1d(self.n_channels, 128, kernel_size=15, padding=7),
nn.BatchNorm1d(128),
nn.ELU(),
nn.MaxPool1d(2),
nn.Conv1d(128, 256, kernel_size=7, padding=3),
nn.BatchNorm1d(256),
nn.ELU(),
nn.AdaptiveAvgPool1d(1)
)
def forward(self, eeg_windows):
"""
eeg_windows: List of tensors [batch, channels, timepoints] for each window
Returns: logits, calcium_states (for visualization/analysis)
"""
batch_size = eeg_windows[0].size(0)
# Process windows sequentially with calcium/W carryover
window_features = []
ca_state = None
W_state = None
ca_history = []
for i, (encoder, ca_module, eeg_window) in enumerate(
zip(self.window_encoders, self.ca_modules, eeg_windows)
):
# Extract CNN features
cnn_features = encoder(eeg_window).squeeze(-1) # [batch, 256]
# Apply constraint satisfaction dynamics
features, ca_state, W_state = ca_module(
cnn_features,
prev_ca=ca_state,
prev_W=W_state
)
window_features.append(features)
ca_history.append(ca_state.detach().cpu().numpy())
# Concatenate all window features
combined = torch.cat(window_features, dim=1)
# Final classification
logits = self.classifier(combined)
return logits, ca_history
class CalciumEEGDataset(Dataset):
"""Dataset that provides EEG data split by time windows"""
def __init__(self, coco_path, annotations_path, split='train',
max_samples=None, trials_to_average=1):
self.coco_path = Path(coco_path)
# Load dataset
print(f"Loading Alljoined ({split})...")
self.dataset = load_dataset("Alljoined/05_125", split=split, streaming=False)
if max_samples:
self.dataset = self.dataset.select(range(min(int(max_samples), len(self.dataset))))
# Load COCO annotations
print(f"Loading COCO annotations...")
with open(annotations_path, 'r') as f:
coco_data = json.load(f)
self.image_categories = defaultdict(set)
for ann in coco_data['annotations']:
img_id = ann['image_id']
if ann['category_id'] in CATEGORY_NAMES:
self.image_categories[img_id].add(ann['category_id'])
# Pre-cache samples
print("Pre-caching EEG data...")
self.samples = []
for idx, sample in enumerate(self.dataset):
coco_id = sample['coco_id']
if coco_id in self.image_categories and len(self.image_categories[coco_id]) > 0:
label = torch.zeros(len(TARGET_CATEGORIES))
for cat_id in self.image_categories[coco_id]:
if cat_id in CATEGORY_NAMES:
cat_idx = list(TARGET_CATEGORIES.values()).index(cat_id)
label[cat_idx] = 1.0
if label.sum() > 0:
self.samples.append((idx, label))
print(f"Cached {len(self.samples)} samples")
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample_idx, label = self.samples[idx]
sample = self.dataset[sample_idx]
eeg_data = np.array(sample['EEG'], dtype=np.float32)
# Extract time windows
eeg_windows = []
for start_ms, end_ms, _ in TIME_WINDOWS:
start_idx = int((start_ms / 1000.0) * EEG_SAMPLE_RATE)
end_idx = int((end_ms / 1000.0) * EEG_SAMPLE_RATE)
if eeg_data.shape[1] >= end_idx:
window = eeg_data[:, start_idx:end_idx]
else:
window = eeg_data[:, start_idx:]
# Pad if needed
if window.shape[1] < (end_idx - start_idx):
pad_width = (end_idx - start_idx) - window.shape[1]
window = np.pad(window, ((0,0), (0, pad_width)), mode='edge')
# Normalize
window = (window - window.mean(axis=1, keepdims=True)) / \
(window.std(axis=1, keepdims=True) + 1e-8)
eeg_windows.append(torch.from_numpy(window).float())
return eeg_windows, label
class CalciumEEGTrainerGUI(tk.Tk):
def __init__(self):
super().__init__()
self.title("Calcium-Bridged Temporal EEG Decoder V2")
self.geometry("1200x850")
self.coco_path = ""
self.annotations_path = ""
self.train_thread = None
self.stop_flag = threading.Event()
self.log_queue = queue.Queue()
self.setup_gui()
self.process_logs()
def setup_gui(self):
# Title
title = tk.Label(self, text="Calcium-Bridged Temporal EEG Decoder (V2 - Extended Window)",
font=("Arial", 14, "bold"))
title.pack(pady=10)
info = tk.Label(self,
text="Sequential constraint satisfaction across 4 ERP time windows up to 550ms\n"
"Calcium/W state from early windows primes later windows",
fg="blue", font=("Arial", 9))
info.pack(pady=5)
# Paths
path_frame = ttk.LabelFrame(self, text="Dataset")
path_frame.pack(pady=5, padx=10, fill=tk.X)
tk.Label(path_frame, text="COCO:").grid(row=0, column=0, sticky=tk.W, padx=5, pady=3)
self.coco_var = tk.StringVar()
ttk.Entry(path_frame, textvariable=self.coco_var, width=50).grid(row=0, column=1, padx=5)
ttk.Button(path_frame, text="Browse", command=self.browse_coco).grid(row=0, column=2)
tk.Label(path_frame, text="Annotations:").grid(row=1, column=0, sticky=tk.W, padx=5, pady=3)
self.ann_var = tk.StringVar()
ttk.Entry(path_frame, textvariable=self.ann_var, width=50).grid(row=1, column=1, padx=5)
ttk.Button(path_frame, text="Browse", command=self.browse_ann).grid(row=1, column=2)
# Settings
settings_frame = ttk.LabelFrame(self, text="Training Settings")
settings_frame.pack(pady=5, padx=10, fill=tk.X)
tk.Label(settings_frame, text="Max Samples:").grid(row=0, column=0, padx=5)
self.max_var = tk.IntVar(value=3000)
tk.Spinbox(settings_frame, from_=1000, to=10000, increment=1000,
textvariable=self.max_var, width=10).grid(row=0, column=1)
tk.Label(settings_frame, text="Epochs:").grid(row=0, column=2, padx=5)
self.epochs_var = tk.IntVar(value=100)
tk.Spinbox(settings_frame, from_=50, to=500, increment=50,
textvariable=self.epochs_var, width=10).grid(row=0, column=3)
# --- CRITICAL CHANGE V2 ---
# Updated GUI to reflect the new 4-stage process.
windows_frame = ttk.LabelFrame(self, text="Constraint Satisfaction Stages")
windows_frame.pack(pady=5, padx=10, fill=tk.X)
for start, end, label in TIME_WINDOWS:
desc = {
"EarlyVisual": "Low-level visual features (edges, textures)",
"MidFeature": "Mid-level binding (parts, shapes)",
"LateSemantic": "High-level semantics (concepts, context)",
"CognitiveEvaluation": "Memory, context check, final decision"
}
tk.Label(windows_frame,
text=f"{label} ({start}-{end}ms): {desc[label]}",
font=("Courier", 9)).pack(anchor=tk.W, padx=10, pady=2)
# Buttons
btn_frame = tk.Frame(self)
btn_frame.pack(pady=10)
self.train_btn = tk.Button(btn_frame, text="Train Extended Model (V2)",
command=self.start_train,
bg="#4CAF50", fg="white", font=("Arial", 10, "bold"))
self.train_btn.pack(side=tk.LEFT, padx=5)
self.stop_btn = tk.Button(btn_frame, text="Stop",
command=self.stop_train,
bg="#f44336", fg="white",
state=tk.DISABLED)
self.stop_btn.pack(side=tk.LEFT, padx=5)
# Progress
self.progress = ttk.Progressbar(self, mode='determinate')
self.progress.pack(fill=tk.X, padx=10, pady=5)
# Log
log_frame = ttk.LabelFrame(self, text="Training Log")
log_frame.pack(pady=5, padx=10, fill=tk.BOTH, expand=True)
self.log_text = tk.Text(log_frame, height=20, bg='black', fg='lightgreen',
font=('Courier', 8))
self.log_text.pack(fill=tk.BOTH, expand=True)
def browse_coco(self):
path = filedialog.askdirectory()
if path:
self.coco_var.set(path)
self.coco_path = path
def browse_ann(self):
path = filedialog.askopenfilename(filetypes=[("JSON", "*.json")])
if path:
self.ann_var.set(path)
self.annotations_path = path
def log(self, msg):
self.log_queue.put(msg)
def process_logs(self):
try:
while not self.log_queue.empty():
msg = self.log_queue.get_nowait()
self.log_text.insert(tk.END, msg + "\n")
self.log_text.see(tk.END)
except queue.Empty:
pass
self.after(100, self.process_logs)
def start_train(self):
if not self.coco_path or not self.annotations_path:
messagebox.showerror("Error", "Select paths first")
return
self.stop_flag.clear()
self.train_btn.config(state=tk.DISABLED)
self.stop_btn.config(state=tk.NORMAL)
self.train_thread = threading.Thread(target=self._train_model, daemon=True)
self.train_thread.start()
def stop_train(self):
self.stop_flag.set()
def _train_model(self):
try:
self.log("="*70)
self.log("CALCIUM-BRIDGED TEMPORAL EEG DECODER (V2 - Extended Window)")
self.log("="*70)
self.log("\nConcept: Sequential constraint satisfaction across FOUR time windows")
self.log("Now capturing late-stage cognitive evaluation signals up to 550ms\n")
# Create dataset
dataset = CalciumEEGDataset(
self.coco_path,
self.annotations_path,
'train',
self.max_var.get()
)
total = len(dataset)
train_size = int(0.8 * total)
val_size = total - train_size
train_set, val_set = torch.utils.data.random_split(
dataset, [train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
self.log(f"Train: {train_size}, Val: {val_size}")
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
shuffle=False, num_workers=0, pin_memory=True)
# Create model (will be automatically sized for 4 windows)
model = TemporalConstraintEEGModel().to(DEVICE)
self.log(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.BCEWithLogitsLoss()
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)
best_val_loss = float('inf')
for epoch in range(self.epochs_var.get()):
if self.stop_flag.is_set():
break
# Train
model.train()
train_loss = 0
for eeg_windows, labels in train_loader:
if self.stop_flag.is_set():
break
eeg_windows = [w.to(DEVICE) for w in eeg_windows]
labels = labels.to(DEVICE)
optimizer.zero_grad()
logits, _ = model(eeg_windows)
loss = criterion(logits, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
train_loss += loss.item()
# Validate
model.eval()
val_loss = 0
with torch.no_grad():
for eeg_windows, labels in val_loader:
eeg_windows = [w.to(DEVICE) for w in eeg_windows]
labels = labels.to(DEVICE)
logits, _ = model(eeg_windows)
loss = criterion(logits, labels)
val_loss += loss.item()
train_loss /= len(train_loader)
val_loss /= len(val_loader)
scheduler.step()
self.progress['value'] = ((epoch + 1) / self.epochs_var.get()) * 100
if epoch % 5 == 0:
self.log(f"Epoch {epoch+1}/{self.epochs_var.get()}: "
f"TrLoss={train_loss:.4f} ValLoss={val_loss:.4f}")
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save({
'model_state_dict': model.state_dict(),
'val_loss': val_loss,
'epoch': epoch
}, "calcium_bridge_eeg_model_v2.pth") # Save as V2
if epoch % 5 == 0:
self.log(f" -> Saved V2 model (val_loss={val_loss:.4f})")
self.log("\n" + "="*70)
self.log("TRAINING COMPLETE")
self.log(f"Best Val Loss: {best_val_loss:.4f}")
self.log("="*70)
except Exception as e:
self.log(f"ERROR: {e}")
import traceback
self.log(traceback.format_exc())
finally:
self.train_btn.config(state=tk.NORMAL)
self.stop_btn.config(state=tk.DISABLED)
if __name__ == "__main__":
app = CalciumEEGTrainerGUI()
app.mainloop()