|
|
"""
Calcium-Bridge EEG Constraint Viewer (V2.1 - Fixed)
Visualizes how constraint satisfaction unfolds across four temporal windows up to 550ms.

Shows:
1. Original COCO image
2. EEG heatmaps for each of the 4 time windows
3. Calcium "attention" evolution (what the model focuses on at each stage)
4. Top predictions crystallizing across the 4 windows

V2.1 Fixes:
- Corrected 'figsize' argument placement during figure creation.
- Corrected colorbar creation to use the figure object directly, resolving warnings.
"""
|
|
|
|
|
|
import json
import os
import random
import tkinter as tk
import traceback
from collections import defaultdict
from pathlib import Path
from tkinter import filedialog, messagebox, ttk

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from PIL import Image, ImageTk
|
|
|
|
|
|
# `load_dataset` is what FilteredTestDataset uses to fetch the EEG test split,
# so the script cannot do anything useful without the `datasets` package.
try:
    from datasets import load_dataset
except ImportError:
    print("Missing datasets library.")
    # The bare `exit()` helper is installed by the `site` module and is not
    # guaranteed to exist (e.g. `python -S`, frozen builds). Raising
    # SystemExit always works, and the non-zero status reports the failure.
    raise SystemExit(1)
|
|
|
|
|
|
# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Samples per second; used to convert millisecond offsets into sample indices.
EEG_SAMPLE_RATE = 512

# (start_ms, end_ms, label) for each analysis window, in processing order.
TIME_WINDOWS = [
    (50, 150, "EarlyVisual"),
    (150, 250, "MidFeature"),
    (250, 350, "LateSemantic"),
    (350, 550, "CognitiveEvaluation"),
]

# Category name -> category id. Insertion order matters: the viewer turns the
# keys into a list and indexes it with classifier output positions.
TARGET_CATEGORIES = {
    "elephant": 22,
    "giraffe": 25,
    "bear": 23,
    "zebra": 24,
    "cow": 21,
    "sheep": 20,
    "horse": 19,
    "dog": 18,
    "cat": 17,
    "bird": 16,
    "airplane": 5,
    "train": 7,
    "boat": 9,
    "bus": 6,
    "truck": 8,
    "motorcycle": 4,
    "bicycle": 2,
    "car": 3,
    "traffic light": 10,
    "fire hydrant": 11,
    "stop sign": 13,
    "parking meter": 14,
    "bench": 15,
}

# Reverse lookup: category id -> category name.
CATEGORY_NAMES = dict(zip(TARGET_CATEGORIES.values(), TARGET_CATEGORIES.keys()))

# Ids the model is meant to recognise.
TARGET_IDS = {*TARGET_CATEGORIES.values()}

# Every id in 1..90, and the ones among them that are not targets.
ALL_COCO_IDS = [*range(1, 91)]
EXCLUDED_IDS = set(ALL_COCO_IDS).difference(TARGET_IDS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CalciumAttentionModule(nn.Module):
    """Gate projected features with a slowly updated per-feature trace.

    The input is projected to ``d_model`` "phase" values. A coherence score
    per feature (mean of ``|cos(phase_i - phase_j)|`` over ``j``) is blended
    into a running trace (``ca``) with weights 0.95 / 0.05. The trace drives
    a sigmoid gate that scales the phase values after they are multiplied by
    a coupling matrix ``W``; the result is added back to the phase values and
    layer-normalised.

    The gate is ``d_model // 2`` wide and is concatenated with itself to
    cover all features, so ``d_model`` has to be even for the shapes to line
    up.
    """

    def __init__(self, n_features, d_model=256):
        """Create the projection, gate, coupling matrix and normalisation.

        Args:
            n_features: Width of the input feature vector.
            d_model: Width of the projected/gated representation.
        """
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model

        # Attribute names and creation order are kept as-is so state-dict
        # keys and seeded initialisation do not change.
        self.phase_proj = nn.Linear(n_features, d_model)
        gate_width = d_model // 2
        self.ca_gate = nn.Sequential(nn.Linear(d_model, gate_width), nn.Sigmoid())
        self.W = nn.Parameter(0.01 * torch.randn(d_model, d_model))
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, prev_ca=None, prev_W=None):
        """Run one gating step.

        Args:
            x: Input of shape ``(batch, n_features)``.
            prev_ca: Trace returned by an earlier step, or ``None`` to start
                from zeros.
            prev_W: Coupling matrix to use instead of this module's own
                ``W``; ``None`` selects ``self.W``.

        Returns:
            ``(features, ca, W)``: the normalised gated features, the updated
            trace, and the coupling matrix that was actually applied.
        """
        phase = self.phase_proj(x)

        if prev_ca is not None:
            trace = prev_ca.clone()
        else:
            trace = torch.zeros(x.size(0), self.d_model, device=x.device)

        coupling = prev_W if prev_W is not None else self.W

        # Pairwise phase differences -> (batch, d_model, d_model).
        pairwise = phase.unsqueeze(2) - phase.unsqueeze(1)
        coherence = pairwise.cos().abs()

        # Exponential blend: mostly the old trace, a little of the new score.
        trace = 0.95 * trace + 0.05 * coherence.mean(dim=2)

        # The half-width gate is tiled twice to span every feature.
        half_gate = self.ca_gate(trace)
        gate = torch.cat((half_gate, half_gate), dim=1)

        gated = torch.matmul(phase, coupling) * gate
        return self.norm(gated + phase), trace, coupling
|
|
|
|
|
|
|
|
|
class TemporalConstraintEEGModel(nn.Module):
    """Classify EEG from one window tensor per entry in ``TIME_WINDOWS``.

    Each window has its own 1-D CNN encoder (ending in global average
    pooling, so window length may vary) and its own
    ``CalciumAttentionModule``. The calcium state and coupling matrix
    returned for one window are fed into the next window's module. The
    per-window 256-d features are concatenated and passed to an MLP
    classifier.

    NOTE(review): as ``CalciumAttentionModule`` is written in this file, a
    module uses the ``prev_W`` it is given instead of its own ``W``. Because
    ``W_state`` is threaded through the loop, every module after the first
    applies the first module's ``W``. This is kept unchanged so existing
    checkpoints behave the same — confirm whether it is intended.
    """

    def __init__(self, n_channels=64, num_classes=len(TARGET_CATEGORIES)):
        """Build per-window encoders, calcium modules and the classifier.

        Args:
            n_channels: Number of EEG channels in each input window.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.n_channels = n_channels

        # One encoder and one calcium module per time window. Names and
        # order are fixed by the checkpoint's state-dict keys.
        self.window_encoders = nn.ModuleList([
            self._build_cnn_encoder() for _ in TIME_WINDOWS
        ])
        self.ca_modules = nn.ModuleList([
            CalciumAttentionModule(256, d_model=256) for _ in TIME_WINDOWS
        ])

        self.classifier = nn.Sequential(
            nn.Linear(256 * len(TIME_WINDOWS), 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def _build_cnn_encoder(self):
        """Return a CNN mapping ``(batch, n_channels, time)`` to ``(batch, 256, 1)``."""
        return nn.Sequential(
            nn.Conv1d(self.n_channels, 128, kernel_size=15, padding=7),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 256, kernel_size=7, padding=3),
            nn.BatchNorm1d(256),
            nn.ELU(),
            nn.AdaptiveAvgPool1d(1),
        )

    def forward(self, eeg_windows, return_intermediates=False):
        """Encode every window in order and classify the combined features.

        Args:
            eeg_windows: Sequence of tensors shaped
                ``(batch, n_channels, time)``, one per time window, in the
                same order as ``TIME_WINDOWS``.
            return_intermediates: When true, also collect the per-window
                calcium states, coupling matrices and "partial" logits.

        Returns:
            ``(logits, ca_history)`` when ``return_intermediates`` is false;
            ``ca_history`` is then an empty list. Otherwise
            ``(logits, ca_history, W_history, window_logits_list)`` where
            entry *i* of ``window_logits_list`` is the classifier output with
            the features of windows after *i* replaced by zeros. All
            intermediates are detached.

        Raises:
            IndexError: If ``eeg_windows`` is empty.
        """
        if len(eeg_windows) == 0:
            # IndexError is what indexing an empty sequence would raise.
            raise IndexError("eeg_windows must contain at least one window tensor")

        window_features, ca_history, W_history, window_logits_list = [], [], [], []
        ca_state, W_state = None, None

        for encoder, ca_module, eeg_window in zip(
            self.window_encoders, self.ca_modules, eeg_windows
        ):
            # (batch, 256, 1) -> (batch, 256)
            cnn_features = encoder(eeg_window).squeeze(-1)
            features, ca_state, W_state = ca_module(cnn_features, ca_state, W_state)
            window_features.append(features)

            if not return_intermediates:
                # Skip the diagnostic classifier passes: nothing would be
                # returned from them, and in train mode they would update
                # the BatchNorm running statistics with zero-padded inputs.
                continue

            ca_history.append(ca_state.detach())
            W_history.append(W_state.detach())

            # Classify what has been seen so far, with zeros standing in for
            # the windows that have not been processed yet.
            n_missing = len(TIME_WINDOWS) - len(window_features)
            padded_features = window_features + [
                torch.zeros_like(features) for _ in range(n_missing)
            ]
            intermediate_logits = self.classifier(torch.cat(padded_features, dim=1))
            window_logits_list.append(intermediate_logits.detach())

        combined = torch.cat(window_features, dim=1)
        logits = self.classifier(combined)

        if return_intermediates:
            return logits, ca_history, W_history, window_logits_list
        return logits, ca_history
|
|
|
|
|
|
|
|
|
|
|
|
class FilteredTestDataset:
    """In-memory subset of the EEG test split restricted to target categories.

    A sample is kept only when the annotations of its image contain at least
    one id from ``TARGET_IDS`` and no id from ``EXCLUDED_IDS``. Kept samples
    are stored in ``self.filtered_samples`` as dicts with the keys
    ``'coco_id'`` and ``'eeg_data'`` (a float32 array).
    """

    def __init__(self, annotations_path, max_samples=1000):
        """Load the EEG split and the COCO annotations, then filter.

        Args:
            annotations_path: Path to a COCO annotations JSON file; its
                ``'annotations'`` list is read for ``image_id`` and
                ``category_id``.
            max_samples: Upper bound on how many rows of the test split are
                examined (before filtering).

        Raises:
            RuntimeError: If no sample survives the filter.
        """
        print("Loading and filtering test dataset...")
        dataset = load_dataset("Alljoined/05_125", split='test', streaming=False)
        # Clamp the request: selecting indices beyond the end of the split
        # would fail when the split has fewer than max_samples rows.
        n_rows = min(max_samples, len(dataset))
        self.eeg_dataset = dataset.select(range(n_rows))

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(annotations_path, 'r', encoding='utf-8') as f:
            coco_data = json.load(f)

        # image id -> set of category ids annotated on that image
        image_annotations = defaultdict(set)
        for ann in coco_data['annotations']:
            image_annotations[ann['image_id']].add(ann['category_id'])

        self.filtered_samples = []
        for sample in self.eeg_dataset:
            # .get() avoids inserting empty entries into the defaultdict.
            ann_ids = image_annotations.get(sample['coco_id'], set())
            has_target = not ann_ids.isdisjoint(TARGET_IDS)
            has_excluded = not ann_ids.isdisjoint(EXCLUDED_IDS)
            if has_target and not has_excluded:
                self.filtered_samples.append({
                    'coco_id': sample['coco_id'],
                    'eeg_data': np.array(sample['EEG'], dtype=np.float32)
                })

        print(f"Loaded {len(self.filtered_samples)} filtered test samples.")
        if not self.filtered_samples:
            raise RuntimeError("No suitable test samples found.")

    def get_eeg_windows(self, sample_info):
        """Cut one sample into the configured time windows and z-score them.

        Args:
            sample_info: Dict with an ``'eeg_data'`` 2-D array; the second
                axis is sliced by sample index (rows are presumably
                channels — confirm against the dataset).

        Returns:
            List with one array per entry of ``TIME_WINDOWS``. Each array has
            ``int(end_ms / 1000 * rate) - int(start_ms / 1000 * rate)``
            columns and is standardised per row (mean 0, std ~1).

        Raises:
            ValueError: If the recording ends before a window starts, so
                there is nothing to pad from.
        """
        eeg_data = sample_info['eeg_data']
        eeg_windows = []

        for start_ms, end_ms, name in TIME_WINDOWS:
            start_idx = int(start_ms / 1000 * EEG_SAMPLE_RATE)
            end_idx = int(end_ms / 1000 * EEG_SAMPLE_RATE)
            n_timepoints = end_idx - start_idx

            # Slicing already stops at the end of the recording, so a short
            # recording simply yields fewer than n_timepoints columns.
            window = eeg_data[:, start_idx:end_idx]

            pad_width = n_timepoints - window.shape[1]
            if pad_width > 0:
                if window.shape[1] == 0:
                    raise ValueError(
                        f"EEG recording has {eeg_data.shape[1]} samples, too short "
                        f"for the {name} window starting at sample {start_idx}"
                    )
                # Repeat the last available sample to reach the full length.
                window = np.pad(window, ((0, 0), (0, pad_width)), 'edge')

            # Per-row standardisation; the epsilon guards flat rows.
            mean = window.mean(axis=1, keepdims=True)
            std = window.std(axis=1, keepdims=True)
            eeg_windows.append((window - mean) / (std + 1e-8))

        return eeg_windows

    def get_random_sample_info(self):
        """Return one kept sample dict chosen uniformly at random."""
        return random.choice(self.filtered_samples)
|
|
|
|
|
|
|
|
|
class CalciumBridgeViewer(tk.Tk):
    """Main window: loads a checkpoint and plots one random test sample.

    The left pane shows the COCO image for the sample. The right pane is a
    notebook with three Matplotlib tabs: per-window predictions, calcium
    state per window, and the EEG windows as heatmaps.
    """

    def __init__(self):
        """Create the window, initialise state and build the widgets."""
        super().__init__()
        self.title("Calcium-Bridge EEG Constraint Viewer V2 (Extended Window)")
        self.geometry("2000x1000")
        self.model, self.test_data = None, None
        # Defined up front so handlers can read them even if the user never
        # pressed a Browse button.
        self.coco_path = ""
        self.annotations_path = ""
        # Tk only keeps a weak link to images; this attribute keeps the
        # currently displayed PhotoImage alive.
        self.pil_image_tk = None
        self.setup_gui()

    def setup_gui(self):
        """Build the control bar, the image pane and the plot notebook."""
        control_frame = ttk.Frame(self)
        control_frame.pack(pady=10, padx=10, fill=tk.X)

        ttk.Label(control_frame, text="COCO Path:").pack(side=tk.LEFT, padx=5)
        self.coco_var = tk.StringVar()
        ttk.Entry(control_frame, textvariable=self.coco_var, width=20).pack(side=tk.LEFT, padx=2)
        ttk.Button(control_frame, text="Browse", command=self.browse_coco).pack(side=tk.LEFT, padx=5)

        ttk.Label(control_frame, text="Annotations:").pack(side=tk.LEFT, padx=5)
        self.ann_var = tk.StringVar()
        ttk.Entry(control_frame, textvariable=self.ann_var, width=20).pack(side=tk.LEFT, padx=2)
        ttk.Button(control_frame, text="Browse", command=self.browse_ann).pack(side=tk.LEFT, padx=5)

        ttk.Button(control_frame, text="Load V2 Model", command=self.load_model).pack(side=tk.LEFT, padx=20)
        # Stays disabled until a model and the test data are loaded.
        self.test_btn = ttk.Button(
            control_frame, text="Test Random Sample", command=self.test_sample, state=tk.DISABLED
        )
        self.test_btn.pack(side=tk.LEFT, padx=5)
        self.status_label = tk.Label(control_frame, text="Model: Not loaded", fg="gray")
        self.status_label.pack(side=tk.LEFT, padx=20)

        main_paned = ttk.PanedWindow(self, orient=tk.HORIZONTAL)
        main_paned.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        image_frame = ttk.Frame(main_paned, width=400)
        main_paned.add(image_frame, weight=0)
        ttk.Label(image_frame, text="COCO Image", font=("Arial", 12, "bold")).pack(pady=5)
        self.image_canvas = tk.Canvas(image_frame, width=400, height=400, bg='lightgray')
        self.image_canvas.pack()
        self.coco_id_label = ttk.Label(image_frame, text="COCO ID: N/A")
        self.coco_id_label.pack(pady=5)

        self.notebook = ttk.Notebook(main_paned)
        main_paned.add(self.notebook, weight=1)
        self.create_tabs()

    def create_tabs(self):
        """Create the three plot tabs and keep their figure/canvas pairs."""
        self.constraint_fig, self.constraint_canvas = self.create_tab(
            "Constraint Satisfaction", "How predictions crystallize as constraints are satisfied"
        )
        self.calcium_fig, self.calcium_canvas = self.create_tab(
            "Calcium Attention", "Calcium state evolution: What the model 'focuses on' at each stage"
        )
        self.eeg_fig, self.eeg_canvas = self.create_tab(
            "EEG Heatmaps", "Raw EEG signals for each time window"
        )

    def create_tab(self, title, description):
        """Add a notebook tab holding a description label and an empty figure.

        Returns:
            ``(fig, canvas)`` — the Matplotlib figure and its Tk canvas.
        """
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text=title)
        ttk.Label(tab, text=description, font=("Arial", 11)).pack(pady=5)
        fig = plt.Figure()
        canvas = FigureCanvasTkAgg(fig, tab)
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        return fig, canvas

    def browse_coco(self):
        """Ask for the COCO image root directory and remember it."""
        path = filedialog.askdirectory()
        if not path:
            return  # dialog cancelled: keep the current selection
        self.coco_var.set(path)
        self.coco_path = path

    def browse_ann(self):
        """Ask for the COCO annotations JSON file and remember it."""
        path = filedialog.askopenfilename(filetypes=[("JSON", "*.json")])
        if not path:
            return  # dialog cancelled: keep the current selection
        self.ann_var.set(path)
        self.annotations_path = path

    def _sync_paths_from_entries(self):
        """Copy the entry-box text into the path attributes.

        This makes paths that were typed or pasted (rather than chosen via
        Browse) take effect.
        """
        self.coco_path = self.coco_var.get().strip()
        self.annotations_path = self.ann_var.get().strip()

    def load_model(self):
        """Load a checkpoint and the filtered test data, then enable testing.

        The checkpoint must be a dict with a ``'model_state_dict'`` entry.
        State is only committed once both the model and the dataset loaded,
        so a failure leaves any previously loaded model untouched.
        """
        self._sync_paths_from_entries()
        if not self.annotations_path:
            # Tell the user what is missing instead of silently doing nothing.
            messagebox.showwarning(
                "Annotations required",
                "Select a COCO annotations JSON file before loading the model.",
            )
            return

        model_path = filedialog.askopenfilename(
            filetypes=[("PyTorch Model", "*.pth")], title="Select calcium_bridge_eeg_model_v2.pth"
        )
        if not model_path:
            return

        try:
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only open checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=DEVICE)
            model = TemporalConstraintEEGModel().to(DEVICE)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            test_data = FilteredTestDataset(self.annotations_path)
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load model:\n{e}")
            print(traceback.format_exc())
            return

        self.model, self.test_data = model, test_data
        self.status_label.config(text="Model: V2 Loaded ✓", fg="green")
        self.test_btn.config(state=tk.NORMAL)

    def _fetch_image(self, coco_id):
        """Return the RGB image for ``coco_id`` or ``None`` if it is not found.

        Looks for ``<coco_path>/<split>/<12-digit id>.jpg`` in the
        ``train2017``, ``val2017`` and ``test2017`` sub-directories, in that
        order.
        """
        if not self.coco_path:
            return None
        formatted_id = f"{coco_id:012d}.jpg"
        for s in ["train2017", "val2017", "test2017"]:
            path = os.path.join(self.coco_path, s, formatted_id)
            if os.path.exists(path):
                # convert() returns a fully loaded copy, so the file handle
                # can be released straight away.
                with Image.open(path) as img:
                    return img.convert("RGB")
        return None

    def test_sample(self):
        """Run the model on one random test sample and refresh every view."""
        if self.model is None or self.test_data is None:
            return
        self._sync_paths_from_entries()
        try:
            sample_info = self.test_data.get_random_sample_info()
            coco_id = sample_info['coco_id']
            image = self._fetch_image(coco_id)
            if image is not None:
                self.display_image(image, coco_id)
            else:
                # Do not leave the previous sample's picture on screen.
                self.image_canvas.delete("all")
                self.coco_id_label.config(text=f"COCO ID: {coco_id} (image not found)")

            eeg_windows_np = self.test_data.get_eeg_windows(sample_info)
            # Add a batch dimension of 1 to every window.
            eeg_windows = [torch.from_numpy(w).unsqueeze(0).to(DEVICE) for w in eeg_windows_np]

            with torch.no_grad():
                logits, ca_history, _, window_logits = self.model(eeg_windows, return_intermediates=True)

            self.visualize_constraint_satisfaction(window_logits, logits)
            self.visualize_calcium_evolution(ca_history)
            self.visualize_eeg_heatmaps(eeg_windows_np)
        except Exception as e:
            messagebox.showerror("Error", f"Failed to process sample:\n{e}")
            print(traceback.format_exc())

    def display_image(self, image, coco_id):
        """Show ``image`` scaled to fit the 400x400 canvas and update the id label."""
        ratio = min(400 / image.width, 400 / image.height)
        # max(1, ...) keeps extremely elongated images from rounding to 0 px.
        new_size = (max(1, int(image.width * ratio)), max(1, int(image.height * ratio)))
        resized = image.resize(new_size, Image.LANCZOS)
        self.pil_image_tk = ImageTk.PhotoImage(resized)
        # Remove the previous image item so items do not pile up and a larger
        # old picture cannot show behind a smaller new one.
        self.image_canvas.delete("all")
        self.image_canvas.create_image(200, 200, image=self.pil_image_tk)
        self.coco_id_label.config(text=f"COCO ID: {coco_id}")

    @staticmethod
    def _window_title(i):
        """Return the subplot title for time window ``i``: label plus ms range."""
        start_ms, end_ms, label = TIME_WINDOWS[i]
        return f"{label}\n({start_ms}-{end_ms}ms)"

    def visualize_constraint_satisfaction(self, window_logits, final_logits):
        """Plot sigmoid scores of the top classes per window and for the final output.

        The ten classes with the highest final score are selected once and
        the same classes are shown in every panel. Class index -> name uses
        the key order of ``TARGET_CATEGORIES`` (assumes the checkpoint was
        trained with that ordering — confirm).

        Args:
            window_logits: One ``(1, num_classes)`` logits tensor per window.
            final_logits: ``(1, num_classes)`` logits from all windows.
        """
        self.constraint_fig.clear()
        cat_list = list(TARGET_CATEGORIES.keys())
        n_windows = len(window_logits)
        final_probs = torch.sigmoid(final_logits).squeeze(0).cpu().numpy()
        top_indices = np.argsort(final_probs)[::-1][:10]
        labels = [cat_list[idx] for idx in top_indices]

        # squeeze=False always yields a 2-D array, whatever the column count.
        axes = self.constraint_fig.subplots(1, n_windows + 1, squeeze=False)[0]

        for i, (ax, wl) in enumerate(zip(axes[:-1], window_logits)):
            probs = torch.sigmoid(wl).squeeze(0).cpu().numpy()[top_indices]
            ax.barh(labels, probs, color='steelblue')
            ax.set_title(self._window_title(i), fontsize=10)
            ax.set_xlim(0, 1)
            ax.invert_yaxis()
            ax.tick_params(axis='y', labelsize=8)

        final_ax = axes[-1]
        final_ax.barh(labels, final_probs[top_indices], color='darkgreen')
        final_ax.set_title("Final\n(Combined)", fontsize=10)
        final_ax.set_xlim(0, 1)
        final_ax.invert_yaxis()
        final_ax.tick_params(axis='y', labelsize=8)

        self.constraint_fig.suptitle(
            "Constraint Satisfaction: Predictions Crystallizing Over Time", fontsize=14, fontweight='bold'
        )
        self.constraint_fig.tight_layout()
        self.constraint_canvas.draw()

    def visualize_calcium_evolution(self, ca_history):
        """Plot each window's calcium state and its 20 largest entries.

        Args:
            ca_history: One ``(1, d_model)`` calcium tensor per window.
        """
        self.calcium_fig.clear()
        n_windows = len(ca_history)
        if n_windows == 0:
            # A zero-column grid cannot be created; just show a blank figure.
            self.calcium_canvas.draw()
            return

        # squeeze=False keeps axes 2-D even when there is a single window.
        axes = self.calcium_fig.subplots(2, n_windows, squeeze=False)

        for i, ca_state in enumerate(ca_history):
            ca_np = ca_state.squeeze(0).cpu().numpy()
            top_20_idx = np.argsort(ca_np)[::-1][:20]

            trace_ax, top_ax = axes[0, i], axes[1, i]
            trace_ax.plot(ca_np, 'r')
            trace_ax.fill_between(range(len(ca_np)), ca_np, color='r', alpha=0.3)
            trace_ax.set_title(self._window_title(i), fontsize=10)

            top_ax.barh([f"F{idx}" for idx in top_20_idx], ca_np[top_20_idx], color='darkred')
            top_ax.invert_yaxis()
            top_ax.tick_params(axis='y', labelsize=7)

        self.calcium_fig.suptitle("Calcium Attention: What the Model Focuses On", fontsize=14, fontweight='bold')
        self.calcium_fig.tight_layout()
        self.calcium_canvas.draw()

    def visualize_eeg_heatmaps(self, eeg_windows_np):
        """Show every EEG window as a heatmap with a fixed -3..3 colour range.

        Args:
            eeg_windows_np: One 2-D array per window, as produced by
                ``FilteredTestDataset.get_eeg_windows``.
        """
        self.eeg_fig.clear()
        n_windows = len(eeg_windows_np)
        if n_windows == 0:
            # A zero-column grid cannot be created; just show a blank figure.
            self.eeg_canvas.draw()
            return

        # squeeze=False keeps the result iterable even for a single window.
        axes = self.eeg_fig.subplots(1, n_windows, squeeze=False)[0]

        for i, (ax, eeg_data) in enumerate(zip(axes, eeg_windows_np)):
            im = ax.imshow(eeg_data, aspect='auto', cmap='RdBu_r', vmin=-3, vmax=3)
            ax.set_title(self._window_title(i), fontsize=10)
            if i == 0:
                ax.set_ylabel("Channel")
            # Colorbar is created through the figure and attached to this axes.
            self.eeg_fig.colorbar(im, ax=ax)

        self.eeg_fig.suptitle("Raw EEG Signals by Time Window", fontsize=14, fontweight='bold')
        self.eeg_fig.tight_layout()
        self.eeg_canvas.draw()
|
|
|
|
|
|
# Start the GUI only when the file is run as a script, not when imported.
if __name__ == "__main__":
    # Build the window, then hand control to Tk's event loop until it closes.
    app = CalciumBridgeViewer()
    app.mainloop()