Upload 3 files
Browse files- .gitattributes +1 -1
- pic.png +3 -0
- pkas_cal_trainer_gemini.py +540 -0
- pkas_cal_viewer_gemini2.py +349 -0
.gitattributes
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.pth filter=lfs diff=lfs merge=lfs -textpic.png filter=lfs diff=lfs merge=lfs -text
|
pic.png
ADDED
|
Git LFS Details
|
pkas_cal_trainer_gemini.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Calcium-Bridged Temporal EEG Decoder (V2 - Extended Time Window)
|
| 3 |
+
Integrates Phase-Calcium-Latent constraint satisfaction dynamics with EEG temporal windows.
|
| 4 |
+
|
| 5 |
+
Core Concept: Each EEG time window is processed by a constraint solver whose
|
| 6 |
+
calcium/W state carries over to initialize the next window, modeling how the brain
|
| 7 |
+
sequentially satisfies perceptual constraints.
|
| 8 |
+
|
| 9 |
+
V2 Update: The time window has been extended to 550ms based on ERP analysis from the
|
| 10 |
+
Alljoined1 paper, adding a 'CognitiveEvaluation' stage to capture late-stage
|
| 11 |
+
semantic and working memory signals.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import json
|
| 16 |
+
import tkinter as tk
|
| 17 |
+
from tkinter import ttk, filedialog, messagebox
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.optim as optim
|
| 21 |
+
from torch.utils.data import Dataset, DataLoader
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import numpy as np
|
| 24 |
+
import threading
|
| 25 |
+
import queue
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from datasets import load_dataset
|
| 33 |
+
torch.backends.cudnn.benchmark = True
|
| 34 |
+
except ImportError as e:
|
| 35 |
+
print(f"Missing dependency: {e}")
|
| 36 |
+
exit()
|
| 37 |
+
|
| 38 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 39 |
+
EEG_SAMPLE_RATE = 512
|
| 40 |
+
BATCH_SIZE = 64
|
| 41 |
+
|
| 42 |
+
# --- CRITICAL CHANGE V2 ---
|
| 43 |
+
# Extended temporal windows to capture later cognitive processing (P300/N400/P600)
|
| 44 |
+
# This aligns the model's "attention span" with the neuroscience data.
|
| 45 |
+
TIME_WINDOWS = [
|
| 46 |
+
(50, 150, "EarlyVisual"), # Low-level visual constraints (P100)
|
| 47 |
+
(150, 250, "MidFeature"), # Mid-level binding (N170/P200)
|
| 48 |
+
(250, 350, "LateSemantic"), # High-level semantics (P300/N400 start)
|
| 49 |
+
(350, 550, "CognitiveEvaluation") # Deeper context, memory, final check (P300/P600)
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
TARGET_CATEGORIES = {
|
| 53 |
+
'elephant': 22, 'giraffe': 25, 'bear': 23, 'zebra': 24,
|
| 54 |
+
'cow': 21, 'sheep': 20, 'horse': 19, 'dog': 18, 'cat': 17, 'bird': 16,
|
| 55 |
+
'airplane': 5, 'train': 7, 'boat': 9, 'bus': 6, 'truck': 8,
|
| 56 |
+
'motorcycle': 4, 'bicycle': 2, 'car': 3,
|
| 57 |
+
'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13,
|
| 58 |
+
'parking meter': 14, 'bench': 15,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
CATEGORY_NAMES = {v: k for k, v in TARGET_CATEGORIES.items()}
|
| 62 |
+
|
| 63 |
+
class CalciumAttentionModule(nn.Module):
|
| 64 |
+
"""
|
| 65 |
+
Phase-Calcium-Latent dynamics for one time window.
|
| 66 |
+
Models constraint satisfaction via neuromorphic oscillator dynamics.
|
| 67 |
+
"""
|
| 68 |
+
def __init__(self, n_features, d_model=256):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.n_features = n_features
|
| 71 |
+
self.d_model = d_model
|
| 72 |
+
|
| 73 |
+
# Phase dynamics (Kuramoto-like)
|
| 74 |
+
self.phase_proj = nn.Linear(n_features, d_model)
|
| 75 |
+
|
| 76 |
+
# Calcium dynamics (gating/attention)
|
| 77 |
+
self.ca_gate = nn.Sequential(
|
| 78 |
+
nn.Linear(d_model, d_model // 2),
|
| 79 |
+
nn.Sigmoid()
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Latent coupling matrix (W) - learned constraint structure
|
| 83 |
+
self.W = nn.Parameter(torch.randn(d_model, d_model) * 0.01)
|
| 84 |
+
|
| 85 |
+
# Layer norm for stability
|
| 86 |
+
self.norm = nn.LayerNorm(d_model)
|
| 87 |
+
|
| 88 |
+
def forward(self, x, prev_ca=None, prev_W=None):
|
| 89 |
+
"""
|
| 90 |
+
x: Input features [batch, n_features]
|
| 91 |
+
prev_ca: Previous window's calcium state [batch, d_model]
|
| 92 |
+
prev_W: Previous window's coupling matrix [d_model, d_model]
|
| 93 |
+
|
| 94 |
+
Returns: features, calcium_state, W_matrix
|
| 95 |
+
"""
|
| 96 |
+
batch_size = x.size(0)
|
| 97 |
+
|
| 98 |
+
# Phase projection
|
| 99 |
+
phi = self.phase_proj(x) # [batch, d_model]
|
| 100 |
+
|
| 101 |
+
# Initialize or carry over calcium
|
| 102 |
+
if prev_ca is None:
|
| 103 |
+
ca = torch.zeros(batch_size, self.d_model, device=x.device)
|
| 104 |
+
else:
|
| 105 |
+
ca = prev_ca.clone()
|
| 106 |
+
|
| 107 |
+
# Initialize or carry over W (coupling structure)
|
| 108 |
+
W = self.W if prev_W is None else prev_W
|
| 109 |
+
|
| 110 |
+
# Calcium accumulation (coherence-based)
|
| 111 |
+
# High when features are aligned (low when conflicting)
|
| 112 |
+
coherence = torch.abs(torch.cos(phi[:, :, None] - phi[:, None, :]))
|
| 113 |
+
ca_update = torch.mean(coherence, dim=2) # [batch, d_model]
|
| 114 |
+
ca = ca * 0.95 + ca_update * 0.05 # Temporal integration
|
| 115 |
+
|
| 116 |
+
# Calcium-gated attention
|
| 117 |
+
ca_gate = self.ca_gate(ca) # [batch, d_model//2]
|
| 118 |
+
|
| 119 |
+
# Apply constraint coupling (W matrix)
|
| 120 |
+
# This is where "mutual constraint satisfaction" happens
|
| 121 |
+
coupled = torch.matmul(phi, W) # [batch, d_model]
|
| 122 |
+
|
| 123 |
+
# Gate the coupling by calcium (only attend where calcium is high)
|
| 124 |
+
ca_gate_full = torch.cat([ca_gate, ca_gate], dim=1) # Expand to d_model
|
| 125 |
+
features = coupled * ca_gate_full
|
| 126 |
+
|
| 127 |
+
# Normalize
|
| 128 |
+
features = self.norm(features + phi) # Residual connection
|
| 129 |
+
|
| 130 |
+
return features, ca, W
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class TemporalConstraintEEGModel(nn.Module):
|
| 134 |
+
"""
|
| 135 |
+
Sequential constraint satisfaction across EEG time windows.
|
| 136 |
+
Each window is a constraint solver whose state primes the next.
|
| 137 |
+
(Dynamically sized based on TIME_WINDOWS constant)
|
| 138 |
+
"""
|
| 139 |
+
def __init__(self, n_channels=64, num_classes=len(TARGET_CATEGORIES)):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.n_channels = n_channels
|
| 142 |
+
|
| 143 |
+
# CNN feature extractors for each time window
|
| 144 |
+
self.window_encoders = nn.ModuleList([
|
| 145 |
+
self._build_cnn_encoder() for _ in TIME_WINDOWS
|
| 146 |
+
])
|
| 147 |
+
|
| 148 |
+
# Calcium-attention modules for each window
|
| 149 |
+
self.ca_modules = nn.ModuleList([
|
| 150 |
+
CalciumAttentionModule(256, d_model=256) for _ in TIME_WINDOWS
|
| 151 |
+
])
|
| 152 |
+
|
| 153 |
+
# --- CRITICAL CHANGE V2 ---
|
| 154 |
+
# The input layer is now automatically larger (256 * 4) because len(TIME_WINDOWS) is 4.
|
| 155 |
+
self.classifier = nn.Sequential(
|
| 156 |
+
nn.Linear(256 * len(TIME_WINDOWS), 512),
|
| 157 |
+
nn.BatchNorm1d(512),
|
| 158 |
+
nn.GELU(),
|
| 159 |
+
nn.Dropout(0.3),
|
| 160 |
+
nn.Linear(512, num_classes)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
def _build_cnn_encoder(self):
|
| 164 |
+
"""Simple CNN for one time window"""
|
| 165 |
+
return nn.Sequential(
|
| 166 |
+
nn.Conv1d(self.n_channels, 128, kernel_size=15, padding=7),
|
| 167 |
+
nn.BatchNorm1d(128),
|
| 168 |
+
nn.ELU(),
|
| 169 |
+
nn.MaxPool1d(2),
|
| 170 |
+
nn.Conv1d(128, 256, kernel_size=7, padding=3),
|
| 171 |
+
nn.BatchNorm1d(256),
|
| 172 |
+
nn.ELU(),
|
| 173 |
+
nn.AdaptiveAvgPool1d(1)
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def forward(self, eeg_windows):
|
| 177 |
+
"""
|
| 178 |
+
eeg_windows: List of tensors [batch, channels, timepoints] for each window
|
| 179 |
+
|
| 180 |
+
Returns: logits, calcium_states (for visualization/analysis)
|
| 181 |
+
"""
|
| 182 |
+
batch_size = eeg_windows[0].size(0)
|
| 183 |
+
|
| 184 |
+
# Process windows sequentially with calcium/W carryover
|
| 185 |
+
window_features = []
|
| 186 |
+
ca_state = None
|
| 187 |
+
W_state = None
|
| 188 |
+
ca_history = []
|
| 189 |
+
|
| 190 |
+
for i, (encoder, ca_module, eeg_window) in enumerate(
|
| 191 |
+
zip(self.window_encoders, self.ca_modules, eeg_windows)
|
| 192 |
+
):
|
| 193 |
+
# Extract CNN features
|
| 194 |
+
cnn_features = encoder(eeg_window).squeeze(-1) # [batch, 256]
|
| 195 |
+
|
| 196 |
+
# Apply constraint satisfaction dynamics
|
| 197 |
+
features, ca_state, W_state = ca_module(
|
| 198 |
+
cnn_features,
|
| 199 |
+
prev_ca=ca_state,
|
| 200 |
+
prev_W=W_state
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
window_features.append(features)
|
| 204 |
+
ca_history.append(ca_state.detach().cpu().numpy())
|
| 205 |
+
|
| 206 |
+
# Concatenate all window features
|
| 207 |
+
combined = torch.cat(window_features, dim=1)
|
| 208 |
+
|
| 209 |
+
# Final classification
|
| 210 |
+
logits = self.classifier(combined)
|
| 211 |
+
|
| 212 |
+
return logits, ca_history
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class CalciumEEGDataset(Dataset):
|
| 216 |
+
"""Dataset that provides EEG data split by time windows"""
|
| 217 |
+
def __init__(self, coco_path, annotations_path, split='train',
|
| 218 |
+
max_samples=None, trials_to_average=1):
|
| 219 |
+
self.coco_path = Path(coco_path)
|
| 220 |
+
|
| 221 |
+
# Load dataset
|
| 222 |
+
print(f"Loading Alljoined ({split})...")
|
| 223 |
+
self.dataset = load_dataset("Alljoined/05_125", split=split, streaming=False)
|
| 224 |
+
|
| 225 |
+
if max_samples:
|
| 226 |
+
self.dataset = self.dataset.select(range(min(int(max_samples), len(self.dataset))))
|
| 227 |
+
|
| 228 |
+
# Load COCO annotations
|
| 229 |
+
print(f"Loading COCO annotations...")
|
| 230 |
+
with open(annotations_path, 'r') as f:
|
| 231 |
+
coco_data = json.load(f)
|
| 232 |
+
|
| 233 |
+
self.image_categories = defaultdict(set)
|
| 234 |
+
for ann in coco_data['annotations']:
|
| 235 |
+
img_id = ann['image_id']
|
| 236 |
+
if ann['category_id'] in CATEGORY_NAMES:
|
| 237 |
+
self.image_categories[img_id].add(ann['category_id'])
|
| 238 |
+
|
| 239 |
+
# Pre-cache samples
|
| 240 |
+
print("Pre-caching EEG data...")
|
| 241 |
+
self.samples = []
|
| 242 |
+
for idx, sample in enumerate(self.dataset):
|
| 243 |
+
coco_id = sample['coco_id']
|
| 244 |
+
if coco_id in self.image_categories and len(self.image_categories[coco_id]) > 0:
|
| 245 |
+
label = torch.zeros(len(TARGET_CATEGORIES))
|
| 246 |
+
for cat_id in self.image_categories[coco_id]:
|
| 247 |
+
if cat_id in CATEGORY_NAMES:
|
| 248 |
+
cat_idx = list(TARGET_CATEGORIES.values()).index(cat_id)
|
| 249 |
+
label[cat_idx] = 1.0
|
| 250 |
+
|
| 251 |
+
if label.sum() > 0:
|
| 252 |
+
self.samples.append((idx, label))
|
| 253 |
+
|
| 254 |
+
print(f"Cached {len(self.samples)} samples")
|
| 255 |
+
|
| 256 |
+
def __len__(self):
|
| 257 |
+
return len(self.samples)
|
| 258 |
+
|
| 259 |
+
def __getitem__(self, idx):
|
| 260 |
+
sample_idx, label = self.samples[idx]
|
| 261 |
+
sample = self.dataset[sample_idx]
|
| 262 |
+
|
| 263 |
+
eeg_data = np.array(sample['EEG'], dtype=np.float32)
|
| 264 |
+
|
| 265 |
+
# Extract time windows
|
| 266 |
+
eeg_windows = []
|
| 267 |
+
for start_ms, end_ms, _ in TIME_WINDOWS:
|
| 268 |
+
start_idx = int((start_ms / 1000.0) * EEG_SAMPLE_RATE)
|
| 269 |
+
end_idx = int((end_ms / 1000.0) * EEG_SAMPLE_RATE)
|
| 270 |
+
|
| 271 |
+
if eeg_data.shape[1] >= end_idx:
|
| 272 |
+
window = eeg_data[:, start_idx:end_idx]
|
| 273 |
+
else:
|
| 274 |
+
window = eeg_data[:, start_idx:]
|
| 275 |
+
# Pad if needed
|
| 276 |
+
if window.shape[1] < (end_idx - start_idx):
|
| 277 |
+
pad_width = (end_idx - start_idx) - window.shape[1]
|
| 278 |
+
window = np.pad(window, ((0,0), (0, pad_width)), mode='edge')
|
| 279 |
+
|
| 280 |
+
# Normalize
|
| 281 |
+
window = (window - window.mean(axis=1, keepdims=True)) / \
|
| 282 |
+
(window.std(axis=1, keepdims=True) + 1e-8)
|
| 283 |
+
|
| 284 |
+
eeg_windows.append(torch.from_numpy(window).float())
|
| 285 |
+
|
| 286 |
+
return eeg_windows, label
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class CalciumEEGTrainerGUI(tk.Tk):
|
| 290 |
+
def __init__(self):
|
| 291 |
+
super().__init__()
|
| 292 |
+
self.title("Calcium-Bridged Temporal EEG Decoder V2")
|
| 293 |
+
self.geometry("1200x850")
|
| 294 |
+
|
| 295 |
+
self.coco_path = ""
|
| 296 |
+
self.annotations_path = ""
|
| 297 |
+
self.train_thread = None
|
| 298 |
+
self.stop_flag = threading.Event()
|
| 299 |
+
self.log_queue = queue.Queue()
|
| 300 |
+
|
| 301 |
+
self.setup_gui()
|
| 302 |
+
self.process_logs()
|
| 303 |
+
|
| 304 |
+
def setup_gui(self):
|
| 305 |
+
# Title
|
| 306 |
+
title = tk.Label(self, text="Calcium-Bridged Temporal EEG Decoder (V2 - Extended Window)",
|
| 307 |
+
font=("Arial", 14, "bold"))
|
| 308 |
+
title.pack(pady=10)
|
| 309 |
+
|
| 310 |
+
info = tk.Label(self,
|
| 311 |
+
text="Sequential constraint satisfaction across 4 ERP time windows up to 550ms\n"
|
| 312 |
+
"Calcium/W state from early windows primes later windows",
|
| 313 |
+
fg="blue", font=("Arial", 9))
|
| 314 |
+
info.pack(pady=5)
|
| 315 |
+
|
| 316 |
+
# Paths
|
| 317 |
+
path_frame = ttk.LabelFrame(self, text="Dataset")
|
| 318 |
+
path_frame.pack(pady=5, padx=10, fill=tk.X)
|
| 319 |
+
|
| 320 |
+
tk.Label(path_frame, text="COCO:").grid(row=0, column=0, sticky=tk.W, padx=5, pady=3)
|
| 321 |
+
self.coco_var = tk.StringVar()
|
| 322 |
+
ttk.Entry(path_frame, textvariable=self.coco_var, width=50).grid(row=0, column=1, padx=5)
|
| 323 |
+
ttk.Button(path_frame, text="Browse", command=self.browse_coco).grid(row=0, column=2)
|
| 324 |
+
|
| 325 |
+
tk.Label(path_frame, text="Annotations:").grid(row=1, column=0, sticky=tk.W, padx=5, pady=3)
|
| 326 |
+
self.ann_var = tk.StringVar()
|
| 327 |
+
ttk.Entry(path_frame, textvariable=self.ann_var, width=50).grid(row=1, column=1, padx=5)
|
| 328 |
+
ttk.Button(path_frame, text="Browse", command=self.browse_ann).grid(row=1, column=2)
|
| 329 |
+
|
| 330 |
+
# Settings
|
| 331 |
+
settings_frame = ttk.LabelFrame(self, text="Training Settings")
|
| 332 |
+
settings_frame.pack(pady=5, padx=10, fill=tk.X)
|
| 333 |
+
|
| 334 |
+
tk.Label(settings_frame, text="Max Samples:").grid(row=0, column=0, padx=5)
|
| 335 |
+
self.max_var = tk.IntVar(value=3000)
|
| 336 |
+
tk.Spinbox(settings_frame, from_=1000, to=10000, increment=1000,
|
| 337 |
+
textvariable=self.max_var, width=10).grid(row=0, column=1)
|
| 338 |
+
|
| 339 |
+
tk.Label(settings_frame, text="Epochs:").grid(row=0, column=2, padx=5)
|
| 340 |
+
self.epochs_var = tk.IntVar(value=100)
|
| 341 |
+
tk.Spinbox(settings_frame, from_=50, to=500, increment=50,
|
| 342 |
+
textvariable=self.epochs_var, width=10).grid(row=0, column=3)
|
| 343 |
+
|
| 344 |
+
# --- CRITICAL CHANGE V2 ---
|
| 345 |
+
# Updated GUI to reflect the new 4-stage process.
|
| 346 |
+
windows_frame = ttk.LabelFrame(self, text="Constraint Satisfaction Stages")
|
| 347 |
+
windows_frame.pack(pady=5, padx=10, fill=tk.X)
|
| 348 |
+
|
| 349 |
+
for start, end, label in TIME_WINDOWS:
|
| 350 |
+
desc = {
|
| 351 |
+
"EarlyVisual": "Low-level visual features (edges, textures)",
|
| 352 |
+
"MidFeature": "Mid-level binding (parts, shapes)",
|
| 353 |
+
"LateSemantic": "High-level semantics (concepts, context)",
|
| 354 |
+
"CognitiveEvaluation": "Memory, context check, final decision"
|
| 355 |
+
}
|
| 356 |
+
tk.Label(windows_frame,
|
| 357 |
+
text=f"{label} ({start}-{end}ms): {desc[label]}",
|
| 358 |
+
font=("Courier", 9)).pack(anchor=tk.W, padx=10, pady=2)
|
| 359 |
+
|
| 360 |
+
# Buttons
|
| 361 |
+
btn_frame = tk.Frame(self)
|
| 362 |
+
btn_frame.pack(pady=10)
|
| 363 |
+
|
| 364 |
+
self.train_btn = tk.Button(btn_frame, text="Train Extended Model (V2)",
|
| 365 |
+
command=self.start_train,
|
| 366 |
+
bg="#4CAF50", fg="white", font=("Arial", 10, "bold"))
|
| 367 |
+
self.train_btn.pack(side=tk.LEFT, padx=5)
|
| 368 |
+
|
| 369 |
+
self.stop_btn = tk.Button(btn_frame, text="Stop",
|
| 370 |
+
command=self.stop_train,
|
| 371 |
+
bg="#f44336", fg="white",
|
| 372 |
+
state=tk.DISABLED)
|
| 373 |
+
self.stop_btn.pack(side=tk.LEFT, padx=5)
|
| 374 |
+
|
| 375 |
+
# Progress
|
| 376 |
+
self.progress = ttk.Progressbar(self, mode='determinate')
|
| 377 |
+
self.progress.pack(fill=tk.X, padx=10, pady=5)
|
| 378 |
+
|
| 379 |
+
# Log
|
| 380 |
+
log_frame = ttk.LabelFrame(self, text="Training Log")
|
| 381 |
+
log_frame.pack(pady=5, padx=10, fill=tk.BOTH, expand=True)
|
| 382 |
+
|
| 383 |
+
self.log_text = tk.Text(log_frame, height=20, bg='black', fg='lightgreen',
|
| 384 |
+
font=('Courier', 8))
|
| 385 |
+
self.log_text.pack(fill=tk.BOTH, expand=True)
|
| 386 |
+
|
| 387 |
+
def browse_coco(self):
|
| 388 |
+
path = filedialog.askdirectory()
|
| 389 |
+
if path:
|
| 390 |
+
self.coco_var.set(path)
|
| 391 |
+
self.coco_path = path
|
| 392 |
+
|
| 393 |
+
def browse_ann(self):
|
| 394 |
+
path = filedialog.askopenfilename(filetypes=[("JSON", "*.json")])
|
| 395 |
+
if path:
|
| 396 |
+
self.ann_var.set(path)
|
| 397 |
+
self.annotations_path = path
|
| 398 |
+
|
| 399 |
+
def log(self, msg):
|
| 400 |
+
self.log_queue.put(msg)
|
| 401 |
+
|
| 402 |
+
def process_logs(self):
|
| 403 |
+
try:
|
| 404 |
+
while not self.log_queue.empty():
|
| 405 |
+
msg = self.log_queue.get_nowait()
|
| 406 |
+
self.log_text.insert(tk.END, msg + "\n")
|
| 407 |
+
self.log_text.see(tk.END)
|
| 408 |
+
except queue.Empty:
|
| 409 |
+
pass
|
| 410 |
+
self.after(100, self.process_logs)
|
| 411 |
+
|
| 412 |
+
def start_train(self):
|
| 413 |
+
if not self.coco_path or not self.annotations_path:
|
| 414 |
+
messagebox.showerror("Error", "Select paths first")
|
| 415 |
+
return
|
| 416 |
+
|
| 417 |
+
self.stop_flag.clear()
|
| 418 |
+
self.train_btn.config(state=tk.DISABLED)
|
| 419 |
+
self.stop_btn.config(state=tk.NORMAL)
|
| 420 |
+
|
| 421 |
+
self.train_thread = threading.Thread(target=self._train_model, daemon=True)
|
| 422 |
+
self.train_thread.start()
|
| 423 |
+
|
| 424 |
+
def stop_train(self):
|
| 425 |
+
self.stop_flag.set()
|
| 426 |
+
|
| 427 |
+
def _train_model(self):
|
| 428 |
+
try:
|
| 429 |
+
self.log("="*70)
|
| 430 |
+
self.log("CALCIUM-BRIDGED TEMPORAL EEG DECODER (V2 - Extended Window)")
|
| 431 |
+
self.log("="*70)
|
| 432 |
+
self.log("\nConcept: Sequential constraint satisfaction across FOUR time windows")
|
| 433 |
+
self.log("Now capturing late-stage cognitive evaluation signals up to 550ms\n")
|
| 434 |
+
|
| 435 |
+
# Create dataset
|
| 436 |
+
dataset = CalciumEEGDataset(
|
| 437 |
+
self.coco_path,
|
| 438 |
+
self.annotations_path,
|
| 439 |
+
'train',
|
| 440 |
+
self.max_var.get()
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
total = len(dataset)
|
| 444 |
+
train_size = int(0.8 * total)
|
| 445 |
+
val_size = total - train_size
|
| 446 |
+
|
| 447 |
+
train_set, val_set = torch.utils.data.random_split(
|
| 448 |
+
dataset, [train_size, val_size],
|
| 449 |
+
generator=torch.Generator().manual_seed(42)
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
self.log(f"Train: {train_size}, Val: {val_size}")
|
| 453 |
+
|
| 454 |
+
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
|
| 455 |
+
shuffle=True, num_workers=0, pin_memory=True)
|
| 456 |
+
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
|
| 457 |
+
shuffle=False, num_workers=0, pin_memory=True)
|
| 458 |
+
|
| 459 |
+
# Create model (will be automatically sized for 4 windows)
|
| 460 |
+
model = TemporalConstraintEEGModel().to(DEVICE)
|
| 461 |
+
self.log(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 462 |
+
|
| 463 |
+
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
|
| 464 |
+
criterion = nn.BCEWithLogitsLoss()
|
| 465 |
+
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)
|
| 466 |
+
|
| 467 |
+
best_val_loss = float('inf')
|
| 468 |
+
|
| 469 |
+
for epoch in range(self.epochs_var.get()):
|
| 470 |
+
if self.stop_flag.is_set():
|
| 471 |
+
break
|
| 472 |
+
|
| 473 |
+
# Train
|
| 474 |
+
model.train()
|
| 475 |
+
train_loss = 0
|
| 476 |
+
for eeg_windows, labels in train_loader:
|
| 477 |
+
if self.stop_flag.is_set():
|
| 478 |
+
break
|
| 479 |
+
|
| 480 |
+
eeg_windows = [w.to(DEVICE) for w in eeg_windows]
|
| 481 |
+
labels = labels.to(DEVICE)
|
| 482 |
+
|
| 483 |
+
optimizer.zero_grad()
|
| 484 |
+
logits, _ = model(eeg_windows)
|
| 485 |
+
loss = criterion(logits, labels)
|
| 486 |
+
loss.backward()
|
| 487 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 488 |
+
optimizer.step()
|
| 489 |
+
|
| 490 |
+
train_loss += loss.item()
|
| 491 |
+
|
| 492 |
+
# Validate
|
| 493 |
+
model.eval()
|
| 494 |
+
val_loss = 0
|
| 495 |
+
with torch.no_grad():
|
| 496 |
+
for eeg_windows, labels in val_loader:
|
| 497 |
+
eeg_windows = [w.to(DEVICE) for w in eeg_windows]
|
| 498 |
+
labels = labels.to(DEVICE)
|
| 499 |
+
logits, _ = model(eeg_windows)
|
| 500 |
+
loss = criterion(logits, labels)
|
| 501 |
+
val_loss += loss.item()
|
| 502 |
+
|
| 503 |
+
train_loss /= len(train_loader)
|
| 504 |
+
val_loss /= len(val_loader)
|
| 505 |
+
|
| 506 |
+
scheduler.step()
|
| 507 |
+
|
| 508 |
+
self.progress['value'] = ((epoch + 1) / self.epochs_var.get()) * 100
|
| 509 |
+
|
| 510 |
+
if epoch % 5 == 0:
|
| 511 |
+
self.log(f"Epoch {epoch+1}/{self.epochs_var.get()}: "
|
| 512 |
+
f"TrLoss={train_loss:.4f} ValLoss={val_loss:.4f}")
|
| 513 |
+
|
| 514 |
+
if val_loss < best_val_loss:
|
| 515 |
+
best_val_loss = val_loss
|
| 516 |
+
torch.save({
|
| 517 |
+
'model_state_dict': model.state_dict(),
|
| 518 |
+
'val_loss': val_loss,
|
| 519 |
+
'epoch': epoch
|
| 520 |
+
}, "calcium_bridge_eeg_model_v2.pth") # Save as V2
|
| 521 |
+
if epoch % 5 == 0:
|
| 522 |
+
self.log(f" -> Saved V2 model (val_loss={val_loss:.4f})")
|
| 523 |
+
|
| 524 |
+
self.log("\n" + "="*70)
|
| 525 |
+
self.log("TRAINING COMPLETE")
|
| 526 |
+
self.log(f"Best Val Loss: {best_val_loss:.4f}")
|
| 527 |
+
self.log("="*70)
|
| 528 |
+
|
| 529 |
+
except Exception as e:
|
| 530 |
+
self.log(f"ERROR: {e}")
|
| 531 |
+
import traceback
|
| 532 |
+
self.log(traceback.format_exc())
|
| 533 |
+
finally:
|
| 534 |
+
self.train_btn.config(state=tk.NORMAL)
|
| 535 |
+
self.stop_btn.config(state=tk.DISABLED)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
if __name__ == "__main__":
|
| 539 |
+
app = CalciumEEGTrainerGUI()
|
| 540 |
+
app.mainloop()
|
pkas_cal_viewer_gemini2.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Calcium-Bridge EEG Constraint Viewer (V2.1 - Fixed)
|
| 3 |
+
Visualizes how constraint satisfaction unfolds across four temporal windows up to 550ms.
|
| 4 |
+
|
| 5 |
+
Shows:
|
| 6 |
+
1. Original COCO image
|
| 7 |
+
2. EEG heatmaps for each of the 4 time windows
|
| 8 |
+
3. Calcium "attention" evolution (what the model focuses on at each stage)
|
| 9 |
+
4. Top predictions crystallizing across the 4 windows
|
| 10 |
+
|
| 11 |
+
V2.1 Fixes:
|
| 12 |
+
- Corrected 'figsize' argument placement during figure creation.
|
| 13 |
+
- Corrected colorbar creation to use the figure object directly, resolving warnings.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import tkinter as tk
|
| 18 |
+
from tkinter import filedialog, messagebox, ttk
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from PIL import Image, ImageTk
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
|
| 26 |
+
import json
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from collections import defaultdict
|
| 29 |
+
import random
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from datasets import load_dataset
|
| 33 |
+
except ImportError:
|
| 34 |
+
print("Missing datasets library.")
|
| 35 |
+
exit()
|
| 36 |
+
|
| 37 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 38 |
+
EEG_SAMPLE_RATE = 512
|
| 39 |
+
|
| 40 |
+
TIME_WINDOWS = [
|
| 41 |
+
(50, 150, "EarlyVisual"),
|
| 42 |
+
(150, 250, "MidFeature"),
|
| 43 |
+
(250, 350, "LateSemantic"),
|
| 44 |
+
(350, 550, "CognitiveEvaluation")
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
TARGET_CATEGORIES = {
|
| 48 |
+
'elephant': 22, 'giraffe': 25, 'bear': 23, 'zebra': 24,
|
| 49 |
+
'cow': 21, 'sheep': 20, 'horse': 19, 'dog': 18, 'cat': 17, 'bird': 16,
|
| 50 |
+
'airplane': 5, 'train': 7, 'boat': 9, 'bus': 6, 'truck': 8,
|
| 51 |
+
'motorcycle': 4, 'bicycle': 2, 'car': 3,
|
| 52 |
+
'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13,
|
| 53 |
+
'parking meter': 14, 'bench': 15,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
CATEGORY_NAMES = {v: k for k, v in TARGET_CATEGORIES.items()}
|
| 57 |
+
TARGET_IDS = set(TARGET_CATEGORIES.values())
|
| 58 |
+
ALL_COCO_IDS = list(range(1, 91))
|
| 59 |
+
EXCLUDED_IDS = set(ALL_COCO_IDS) - TARGET_IDS
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# === MODEL ARCHITECTURE (Must match V2 training code) ===
|
| 63 |
+
|
| 64 |
+
class CalciumAttentionModule(nn.Module):
|
| 65 |
+
def __init__(self, n_features, d_model=256):
|
| 66 |
+
super().__init__()
|
| 67 |
+
self.n_features = n_features
|
| 68 |
+
self.d_model = d_model
|
| 69 |
+
self.phase_proj = nn.Linear(n_features, d_model)
|
| 70 |
+
self.ca_gate = nn.Sequential(
|
| 71 |
+
nn.Linear(d_model, d_model // 2),
|
| 72 |
+
nn.Sigmoid()
|
| 73 |
+
)
|
| 74 |
+
self.W = nn.Parameter(torch.randn(d_model, d_model) * 0.01)
|
| 75 |
+
self.norm = nn.LayerNorm(d_model)
|
| 76 |
+
|
| 77 |
+
def forward(self, x, prev_ca=None, prev_W=None):
|
| 78 |
+
batch_size = x.size(0)
|
| 79 |
+
phi = self.phase_proj(x)
|
| 80 |
+
|
| 81 |
+
if prev_ca is None:
|
| 82 |
+
ca = torch.zeros(batch_size, self.d_model, device=x.device)
|
| 83 |
+
else:
|
| 84 |
+
ca = prev_ca.clone()
|
| 85 |
+
|
| 86 |
+
W = self.W if prev_W is None else prev_W
|
| 87 |
+
|
| 88 |
+
coherence = torch.abs(torch.cos(phi[:, :, None] - phi[:, None, :]))
|
| 89 |
+
ca_update = torch.mean(coherence, dim=2)
|
| 90 |
+
ca = ca * 0.95 + ca_update * 0.05
|
| 91 |
+
|
| 92 |
+
ca_gate = self.ca_gate(ca)
|
| 93 |
+
coupled = torch.matmul(phi, W)
|
| 94 |
+
ca_gate_full = torch.cat([ca_gate, ca_gate], dim=1)
|
| 95 |
+
features = coupled * ca_gate_full
|
| 96 |
+
features = self.norm(features + phi)
|
| 97 |
+
|
| 98 |
+
return features, ca, W
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class TemporalConstraintEEGModel(nn.Module):
|
| 102 |
+
def __init__(self, n_channels=64, num_classes=len(TARGET_CATEGORIES)):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.n_channels = n_channels
|
| 105 |
+
|
| 106 |
+
self.window_encoders = nn.ModuleList([
|
| 107 |
+
self._build_cnn_encoder() for _ in TIME_WINDOWS
|
| 108 |
+
])
|
| 109 |
+
|
| 110 |
+
self.ca_modules = nn.ModuleList([
|
| 111 |
+
CalciumAttentionModule(256, d_model=256) for _ in TIME_WINDOWS
|
| 112 |
+
])
|
| 113 |
+
|
| 114 |
+
self.classifier = nn.Sequential(
|
| 115 |
+
nn.Linear(256 * len(TIME_WINDOWS), 512),
|
| 116 |
+
nn.BatchNorm1d(512),
|
| 117 |
+
nn.GELU(),
|
| 118 |
+
nn.Dropout(0.3),
|
| 119 |
+
nn.Linear(512, num_classes)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
def _build_cnn_encoder(self):
|
| 123 |
+
return nn.Sequential(
|
| 124 |
+
nn.Conv1d(self.n_channels, 128, kernel_size=15, padding=7),
|
| 125 |
+
nn.BatchNorm1d(128),
|
| 126 |
+
nn.ELU(),
|
| 127 |
+
nn.MaxPool1d(2),
|
| 128 |
+
nn.Conv1d(128, 256, kernel_size=7, padding=3),
|
| 129 |
+
nn.BatchNorm1d(256),
|
| 130 |
+
nn.ELU(),
|
| 131 |
+
nn.AdaptiveAvgPool1d(1)
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def forward(self, eeg_windows, return_intermediates=False):
|
| 135 |
+
batch_size = eeg_windows[0].size(0)
|
| 136 |
+
|
| 137 |
+
window_features, ca_history, W_history, window_logits_list = [], [], [], []
|
| 138 |
+
ca_state, W_state = None, None
|
| 139 |
+
|
| 140 |
+
for i, (encoder, ca_module, eeg_window) in enumerate(
|
| 141 |
+
zip(self.window_encoders, self.ca_modules, eeg_windows)
|
| 142 |
+
):
|
| 143 |
+
cnn_features = encoder(eeg_window).squeeze(-1)
|
| 144 |
+
features, ca_state, W_state = ca_module(cnn_features, ca_state, W_state)
|
| 145 |
+
|
| 146 |
+
window_features.append(features)
|
| 147 |
+
if return_intermediates:
|
| 148 |
+
ca_history.append(ca_state.detach())
|
| 149 |
+
W_history.append(W_state.detach())
|
| 150 |
+
|
| 151 |
+
padded_features = window_features + [
|
| 152 |
+
torch.zeros_like(features) for _ in range(len(TIME_WINDOWS) - len(window_features))
|
| 153 |
+
]
|
| 154 |
+
intermediate_logits = self.classifier(torch.cat(padded_features, dim=1))
|
| 155 |
+
window_logits_list.append(intermediate_logits.detach())
|
| 156 |
+
|
| 157 |
+
combined = torch.cat(window_features, dim=1)
|
| 158 |
+
logits = self.classifier(combined)
|
| 159 |
+
|
| 160 |
+
if return_intermediates:
|
| 161 |
+
return logits, ca_history, W_history, window_logits_list
|
| 162 |
+
return logits, ca_history
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# === DATA LOADER ===
|
| 166 |
+
class FilteredTestDataset:
|
| 167 |
+
def __init__(self, annotations_path, max_samples=1000):
|
| 168 |
+
print("Loading and filtering test dataset...")
|
| 169 |
+
self.eeg_dataset = load_dataset("Alljoined/05_125", split='test', streaming=False).select(range(max_samples))
|
| 170 |
+
with open(annotations_path, 'r') as f:
|
| 171 |
+
coco_data = json.load(f)
|
| 172 |
+
|
| 173 |
+
image_annotations = defaultdict(set)
|
| 174 |
+
for ann in coco_data['annotations']:
|
| 175 |
+
image_annotations[ann['image_id']].add(ann['category_id'])
|
| 176 |
+
|
| 177 |
+
self.filtered_samples = []
|
| 178 |
+
for idx, sample in enumerate(self.eeg_dataset):
|
| 179 |
+
ann_ids = image_annotations.get(sample['coco_id'], set())
|
| 180 |
+
if not any(cat_id in EXCLUDED_IDS for cat_id in ann_ids) and any(cat_id in TARGET_IDS for cat_id in ann_ids):
|
| 181 |
+
self.filtered_samples.append({
|
| 182 |
+
'coco_id': sample['coco_id'],
|
| 183 |
+
'eeg_data': np.array(sample['EEG'], dtype=np.float32)
|
| 184 |
+
})
|
| 185 |
+
print(f"Loaded {len(self.filtered_samples)} filtered test samples.")
|
| 186 |
+
if not self.filtered_samples: raise RuntimeError("No suitable test samples found.")
|
| 187 |
+
|
| 188 |
+
def get_eeg_windows(self, sample_info):
|
| 189 |
+
eeg_data = sample_info['eeg_data']
|
| 190 |
+
eeg_windows = []
|
| 191 |
+
for start_ms, end_ms, _ in TIME_WINDOWS:
|
| 192 |
+
start_idx, end_idx = int(start_ms / 1000 * EEG_SAMPLE_RATE), int(end_ms / 1000 * EEG_SAMPLE_RATE)
|
| 193 |
+
n_timepoints = end_idx - start_idx
|
| 194 |
+
window = eeg_data[:, start_idx:end_idx] if eeg_data.shape[1] >= end_idx else eeg_data[:, start_idx:]
|
| 195 |
+
|
| 196 |
+
if window.shape[1] != n_timepoints:
|
| 197 |
+
pad_width = n_timepoints - window.shape[1]
|
| 198 |
+
window = np.pad(window, ((0,0), (0, pad_width)), 'edge') if pad_width > 0 else window[:, :n_timepoints]
|
| 199 |
+
|
| 200 |
+
window = (window - window.mean(axis=1, keepdims=True)) / (window.std(axis=1, keepdims=True) + 1e-8)
|
| 201 |
+
eeg_windows.append(window)
|
| 202 |
+
return eeg_windows
|
| 203 |
+
|
| 204 |
+
def get_random_sample_info(self):
|
| 205 |
+
return random.choice(self.filtered_samples)
|
| 206 |
+
|
| 207 |
+
# === VIEWER APPLICATION ===
|
| 208 |
+
class CalciumBridgeViewer(tk.Tk):
|
| 209 |
+
def __init__(self):
|
| 210 |
+
super().__init__()
|
| 211 |
+
self.title("Calcium-Bridge EEG Constraint Viewer V2 (Extended Window)")
|
| 212 |
+
self.geometry("2000x1000")
|
| 213 |
+
self.model, self.test_data = None, None
|
| 214 |
+
self.setup_gui()
|
| 215 |
+
|
| 216 |
+
def setup_gui(self):
|
| 217 |
+
control_frame = ttk.Frame(self); control_frame.pack(pady=10, padx=10, fill=tk.X)
|
| 218 |
+
ttk.Label(control_frame, text="COCO Path:").pack(side=tk.LEFT, padx=5)
|
| 219 |
+
self.coco_var = tk.StringVar(); ttk.Entry(control_frame, textvariable=self.coco_var, width=20).pack(side=tk.LEFT, padx=2)
|
| 220 |
+
ttk.Button(control_frame, text="Browse", command=self.browse_coco).pack(side=tk.LEFT, padx=5)
|
| 221 |
+
ttk.Label(control_frame, text="Annotations:").pack(side=tk.LEFT, padx=5)
|
| 222 |
+
self.ann_var = tk.StringVar(); ttk.Entry(control_frame, textvariable=self.ann_var, width=20).pack(side=tk.LEFT, padx=2)
|
| 223 |
+
ttk.Button(control_frame, text="Browse", command=self.browse_ann).pack(side=tk.LEFT, padx=5)
|
| 224 |
+
ttk.Button(control_frame, text="Load V2 Model", command=self.load_model).pack(side=tk.LEFT, padx=20)
|
| 225 |
+
self.test_btn = ttk.Button(control_frame, text="Test Random Sample", command=self.test_sample, state=tk.DISABLED); self.test_btn.pack(side=tk.LEFT, padx=5)
|
| 226 |
+
self.status_label = tk.Label(control_frame, text="Model: Not loaded", fg="gray"); self.status_label.pack(side=tk.LEFT, padx=20)
|
| 227 |
+
|
| 228 |
+
main_paned = ttk.PanedWindow(self, orient=tk.HORIZONTAL); main_paned.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
|
| 229 |
+
image_frame = ttk.Frame(main_paned, width=400); main_paned.add(image_frame, weight=0)
|
| 230 |
+
ttk.Label(image_frame, text="COCO Image", font=("Arial", 12, "bold")).pack(pady=5)
|
| 231 |
+
self.image_canvas = tk.Canvas(image_frame, width=400, height=400, bg='lightgray'); self.image_canvas.pack()
|
| 232 |
+
self.coco_id_label = ttk.Label(image_frame, text="COCO ID: N/A"); self.coco_id_label.pack(pady=5)
|
| 233 |
+
|
| 234 |
+
self.notebook = ttk.Notebook(main_paned); main_paned.add(self.notebook, weight=1)
|
| 235 |
+
self.create_tabs()
|
| 236 |
+
|
| 237 |
+
def create_tabs(self):
|
| 238 |
+
self.constraint_fig, self.constraint_canvas = self.create_tab("Constraint Satisfaction", "How predictions crystallize as constraints are satisfied")
|
| 239 |
+
self.calcium_fig, self.calcium_canvas = self.create_tab("Calcium Attention", "Calcium state evolution: What the model 'focuses on' at each stage")
|
| 240 |
+
self.eeg_fig, self.eeg_canvas = self.create_tab("EEG Heatmaps", "Raw EEG signals for each time window")
|
| 241 |
+
|
| 242 |
+
def create_tab(self, title, description):
|
| 243 |
+
tab = ttk.Frame(self.notebook); self.notebook.add(tab, text=title)
|
| 244 |
+
ttk.Label(tab, text=description, font=("Arial", 11)).pack(pady=5)
|
| 245 |
+
fig = plt.Figure()
|
| 246 |
+
canvas = FigureCanvasTkAgg(fig, tab); canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
|
| 247 |
+
return fig, canvas
|
| 248 |
+
|
| 249 |
+
def browse_coco(self):
|
| 250 |
+
path = filedialog.askdirectory(); self.coco_var.set(path); self.coco_path = path
|
| 251 |
+
|
| 252 |
+
def browse_ann(self):
|
| 253 |
+
path = filedialog.askopenfilename(filetypes=[("JSON", "*.json")]); self.ann_var.set(path); self.annotations_path = path
|
| 254 |
+
|
| 255 |
+
def load_model(self):
|
| 256 |
+
model_path = filedialog.askopenfilename(filetypes=[("PyTorch Model", "*.pth")], title="Select calcium_bridge_eeg_model_v2.pth")
|
| 257 |
+
if not model_path or not self.annotations_path: return
|
| 258 |
+
try:
|
| 259 |
+
checkpoint = torch.load(model_path, map_location=DEVICE)
|
| 260 |
+
self.model = TemporalConstraintEEGModel().to(DEVICE)
|
| 261 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 262 |
+
self.model.eval()
|
| 263 |
+
self.test_data = FilteredTestDataset(self.annotations_path)
|
| 264 |
+
self.status_label.config(text="Model: V2 Loaded ✓", fg="green")
|
| 265 |
+
self.test_btn.config(state=tk.NORMAL)
|
| 266 |
+
except Exception as e: messagebox.showerror("Error", f"Failed to load model:\n{e}"); print(traceback.format_exc())
|
| 267 |
+
|
| 268 |
+
def _fetch_image(self, coco_id):
|
| 269 |
+
formatted_id = f"{coco_id:012d}.jpg"
|
| 270 |
+
for s in ["train2017", "val2017", "test2017"]:
|
| 271 |
+
path = os.path.join(self.coco_path, s, formatted_id)
|
| 272 |
+
if os.path.exists(path): return Image.open(path).convert("RGB")
|
| 273 |
+
return None
|
| 274 |
+
|
| 275 |
+
def test_sample(self):
|
| 276 |
+
if not self.model: return
|
| 277 |
+
try:
|
| 278 |
+
sample_info = self.test_data.get_random_sample_info()
|
| 279 |
+
image = self._fetch_image(sample_info['coco_id'])
|
| 280 |
+
if image: self.display_image(image, sample_info['coco_id'])
|
| 281 |
+
|
| 282 |
+
eeg_windows_np = self.test_data.get_eeg_windows(sample_info)
|
| 283 |
+
eeg_windows = [torch.from_numpy(w).unsqueeze(0).to(DEVICE) for w in eeg_windows_np]
|
| 284 |
+
|
| 285 |
+
with torch.no_grad():
|
| 286 |
+
logits, ca_history, _, window_logits = self.model(eeg_windows, return_intermediates=True)
|
| 287 |
+
|
| 288 |
+
self.visualize_constraint_satisfaction(window_logits, logits)
|
| 289 |
+
self.visualize_calcium_evolution(ca_history)
|
| 290 |
+
self.visualize_eeg_heatmaps(eeg_windows_np)
|
| 291 |
+
except Exception as e: messagebox.showerror("Error", f"Failed to process sample:\n{e}"); print(traceback.format_exc())
|
| 292 |
+
|
| 293 |
+
def display_image(self, image, coco_id):
|
| 294 |
+
ratio = min(400/image.width, 400/image.height)
|
| 295 |
+
resized = image.resize((int(image.width * ratio), int(image.height * ratio)), Image.LANCZOS)
|
| 296 |
+
self.pil_image_tk = ImageTk.PhotoImage(resized)
|
| 297 |
+
self.image_canvas.create_image(200, 200, image=self.pil_image_tk)
|
| 298 |
+
self.coco_id_label.config(text=f"COCO ID: {coco_id}")
|
| 299 |
+
|
| 300 |
+
def visualize_constraint_satisfaction(self, window_logits, final_logits):
|
| 301 |
+
self.constraint_fig.clear()
|
| 302 |
+
cat_list = list(TARGET_CATEGORIES.keys())
|
| 303 |
+
n_windows = len(window_logits)
|
| 304 |
+
final_probs = torch.sigmoid(final_logits).squeeze(0).cpu().numpy()
|
| 305 |
+
top_indices = np.argsort(final_probs)[::-1][:10]
|
| 306 |
+
axes = self.constraint_fig.subplots(1, n_windows + 1)
|
| 307 |
+
|
| 308 |
+
for i, (ax, wl) in enumerate(zip(axes[:-1], window_logits)):
|
| 309 |
+
probs = torch.sigmoid(wl).squeeze(0).cpu().numpy()[top_indices]
|
| 310 |
+
ax.barh([cat_list[idx] for idx in top_indices], probs, color='steelblue')
|
| 311 |
+
ax.set_title(f"{TIME_WINDOWS[i][2]}\n({TIME_WINDOWS[i][0]}-{TIME_WINDOWS[i][1]}ms)", fontsize=10)
|
| 312 |
+
ax.set_xlim(0, 1); ax.invert_yaxis(); ax.tick_params(axis='y', labelsize=8)
|
| 313 |
+
|
| 314 |
+
axes[-1].barh([cat_list[idx] for idx in top_indices], final_probs[top_indices], color='darkgreen')
|
| 315 |
+
axes[-1].set_title("Final\n(Combined)", fontsize=10); axes[-1].set_xlim(0, 1); axes[-1].invert_yaxis(); axes[-1].tick_params(axis='y', labelsize=8)
|
| 316 |
+
self.constraint_fig.suptitle("Constraint Satisfaction: Predictions Crystallizing Over Time", fontsize=14, fontweight='bold')
|
| 317 |
+
self.constraint_fig.tight_layout(); self.constraint_canvas.draw()
|
| 318 |
+
|
| 319 |
+
def visualize_calcium_evolution(self, ca_history):
|
| 320 |
+
self.calcium_fig.clear()
|
| 321 |
+
n_windows = len(ca_history)
|
| 322 |
+
axes = self.calcium_fig.subplots(2, n_windows)
|
| 323 |
+
|
| 324 |
+
for i, ca_state in enumerate(ca_history):
|
| 325 |
+
ca_np = ca_state.squeeze(0).cpu().numpy()
|
| 326 |
+
top_20_idx = np.argsort(ca_np)[::-1][:20]
|
| 327 |
+
axes[0, i].plot(ca_np, 'r'); axes[0, i].fill_between(range(len(ca_np)), ca_np, color='r', alpha=0.3)
|
| 328 |
+
axes[0, i].set_title(f"{TIME_WINDOWS[i][2]}\n({TIME_WINDOWS[i][0]}-{TIME_WINDOWS[i][1]}ms)", fontsize=10)
|
| 329 |
+
axes[1, i].barh([f"F{idx}" for idx in top_20_idx], ca_np[top_20_idx], color='darkred')
|
| 330 |
+
axes[1, i].invert_yaxis(); axes[1, i].tick_params(axis='y', labelsize=7)
|
| 331 |
+
self.calcium_fig.suptitle("Calcium Attention: What the Model Focuses On", fontsize=14, fontweight='bold')
|
| 332 |
+
self.calcium_fig.tight_layout(); self.calcium_canvas.draw()
|
| 333 |
+
|
| 334 |
+
def visualize_eeg_heatmaps(self, eeg_windows_np):
|
| 335 |
+
self.eeg_fig.clear()
|
| 336 |
+
n_windows = len(eeg_windows_np)
|
| 337 |
+
axes = self.eeg_fig.subplots(1, n_windows)
|
| 338 |
+
|
| 339 |
+
for i, (ax, eeg_data) in enumerate(zip(axes, eeg_windows_np)):
|
| 340 |
+
im = ax.imshow(eeg_data, aspect='auto', cmap='RdBu_r', vmin=-3, vmax=3)
|
| 341 |
+
ax.set_title(f"{TIME_WINDOWS[i][2]}\n({TIME_WINDOWS[i][0]}-{TIME_WINDOWS[i][1]}ms)", fontsize=10)
|
| 342 |
+
if i == 0: ax.set_ylabel("Channel")
|
| 343 |
+
self.eeg_fig.colorbar(im, ax=ax) # CORRECTED
|
| 344 |
+
self.eeg_fig.suptitle("Raw EEG Signals by Time Window", fontsize=14, fontweight='bold')
|
| 345 |
+
self.eeg_fig.tight_layout(); self.eeg_canvas.draw()
|
| 346 |
+
|
| 347 |
+
if __name__ == "__main__":
|
| 348 |
+
app = CalciumBridgeViewer()
|
| 349 |
+
app.mainloop()
|