# CalciumBridgeEEGConstraintViewer / pkas_cal_viewer_gemini2.py
# Aluode's picture
# Upload 3 files
# 258e525 verified
"""
Calcium-Bridge EEG Constraint Viewer (V2.1 - Fixed)
Visualizes how constraint satisfaction unfolds across four temporal windows up to 550ms.
Shows:
1. Original COCO image
2. EEG heatmaps for each of the 4 time windows
3. Calcium "attention" evolution (what the model focuses on at each stage)
4. Top predictions crystallizing across the 4 windows
V2.1 Fixes:
- Corrected 'figsize' argument placement during figure creation.
- Corrected colorbar creation to use the figure object directly, resolving warnings.
"""
import json
import os
import random
import tkinter as tk
import traceback
from collections import defaultdict
from pathlib import Path
from tkinter import filedialog, messagebox, ttk

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from PIL import Image, ImageTk
try:
from datasets import load_dataset
except ImportError:
print("Missing datasets library.")
exit()
# Device used for the model and input tensors: CUDA when available, else CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sampling rate used to convert the millisecond bounds below into sample indices.
EEG_SAMPLE_RATE = 512

# (start_ms, end_ms, label) for each temporal window, in processing order.
TIME_WINDOWS = [
    (50, 150, "EarlyVisual"),
    (150, 250, "MidFeature"),
    (250, 350, "LateSemantic"),
    (350, 550, "CognitiveEvaluation"),
]

# Category name -> COCO category id.
# NOTE: insertion order is significant. The viewer turns a logit index into a
# name via list(TARGET_CATEGORIES.keys()), so do not reorder these entries.
TARGET_CATEGORIES = {
    'elephant': 22,
    'giraffe': 25,
    'bear': 23,
    'zebra': 24,
    'cow': 21,
    'sheep': 20,
    'horse': 19,
    'dog': 18,
    'cat': 17,
    'bird': 16,
    'airplane': 5,
    'train': 7,
    'boat': 9,
    'bus': 6,
    'truck': 8,
    'motorcycle': 4,
    'bicycle': 2,
    'car': 3,
    'traffic light': 10,
    'fire hydrant': 11,
    'stop sign': 13,
    'parking meter': 14,
    'bench': 15,
}

# Reverse lookup: COCO category id -> category name.
CATEGORY_NAMES = {cat_id: name for name, cat_id in TARGET_CATEGORIES.items()}

# Ids of the categories the model predicts.
TARGET_IDS = {*TARGET_CATEGORIES.values()}

# Every id in 1..90, and the subset of those that is not a target category.
ALL_COCO_IDS = [*range(1, 91)]
EXCLUDED_IDS = set(ALL_COCO_IDS).difference(TARGET_IDS)
# === MODEL ARCHITECTURE (Must match V2 training code) ===
class CalciumAttentionModule(nn.Module):
    """Gated feature transform driven by a slowly-updated "calcium" state.

    The input is projected to ``d_model`` "phase" values. A per-feature
    coherence score (mean absolute cosine of pairwise phase differences) is
    blended into the calcium state, and a sigmoid gate computed from that
    state scales the phase values after they are multiplied by a coupling
    matrix.

    Attribute names are part of the checkpoint format (state-dict keys) and
    must not be renamed.
    """

    def __init__(self, n_features, d_model=256):
        """Create the projection, gate, coupling matrix and layer norm.

        Args:
            n_features: Size of the last dimension of the input to ``forward``.
            d_model: Width of the phase/calcium vectors. The gate produces
                ``d_model // 2`` values that are duplicated, so an even
                ``d_model`` is needed for the shapes to line up.
        """
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model
        self.phase_proj = nn.Linear(n_features, d_model)
        self.ca_gate = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.Sigmoid(),
        )
        self.W = nn.Parameter(torch.randn(d_model, d_model) * 0.01)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, prev_ca=None, prev_W=None):
        """Run one gating step.

        Args:
            x: Tensor whose first dimension is the batch and whose last
                dimension is ``n_features``.
            prev_ca: Calcium state from an earlier step, or ``None`` to start
                from zeros.
            prev_W: Coupling matrix to use instead of this module's own
                ``self.W``; ``None`` selects ``self.W``.

        Returns:
            ``(features, ca, W)``: the normalised gated features, the updated
            calcium state, and the coupling matrix that was actually used.
        """
        n_batch = x.size(0)
        phase = self.phase_proj(x)

        if prev_ca is not None:
            calcium = prev_ca.clone()
        else:
            calcium = torch.zeros(n_batch, self.d_model, device=x.device)

        coupling = prev_W if prev_W is not None else self.W

        # |cos(phi_i - phi_j)| for every feature pair, averaged over j.
        pairwise_delta = phase.unsqueeze(2) - phase.unsqueeze(1)
        coherence = torch.cos(pairwise_delta).abs()
        calcium_update = coherence.mean(dim=2)

        # Leaky integration: keep 95% of the old state, mix in 5% of the new.
        calcium = calcium * 0.95 + calcium_update * 0.05

        # The gate has d_model // 2 outputs; duplicate it to cover d_model.
        half_gate = self.ca_gate(calcium)
        full_gate = torch.cat((half_gate, half_gate), dim=1)

        gated = torch.matmul(phase, coupling) * full_gate
        # Residual connection back to the un-gated phase values.
        return self.norm(gated + phase), calcium, coupling
class TemporalConstraintEEGModel(nn.Module):
    """Per-window CNN encoders + calcium attention, followed by one classifier.

    One encoder and one ``CalciumAttentionModule`` are created per entry in
    ``TIME_WINDOWS``. The calcium state and coupling matrix returned by each
    attention module are passed on to the next window's module. The
    classifier consumes the concatenation of all per-window feature vectors.

    Attribute names and layer order are part of the checkpoint format and
    must stay as they are.
    """

    def __init__(self, n_channels=64, num_classes=len(TARGET_CATEGORIES)):
        """Build the encoders, attention modules and classifier head.

        Args:
            n_channels: Number of input channels expected by each encoder's
                first ``Conv1d``.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.n_channels = n_channels
        n_windows = len(TIME_WINDOWS)
        self.window_encoders = nn.ModuleList(
            [self._build_cnn_encoder() for _ in range(n_windows)]
        )
        self.ca_modules = nn.ModuleList(
            [CalciumAttentionModule(256, d_model=256) for _ in range(n_windows)]
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * n_windows, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def _build_cnn_encoder(self):
        """Return a fresh two-stage Conv1d encoder ending in a length-1 average pool."""
        layers = [
            nn.Conv1d(self.n_channels, 128, kernel_size=15, padding=7),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 256, kernel_size=7, padding=3),
            nn.BatchNorm1d(256),
            nn.ELU(),
            nn.AdaptiveAvgPool1d(1),
        ]
        return nn.Sequential(*layers)

    def forward(self, eeg_windows, return_intermediates=False):
        """Classify a sequence of EEG windows.

        Args:
            eeg_windows: Sequence of tensors, one per time window, each fed to
                the encoder at the same position.
            return_intermediates: When true, also collect per-window calcium
                states, coupling matrices and intermediate logits.

        Returns:
            ``(logits, ca_history, W_history, window_logits_list)`` when
            ``return_intermediates`` is true; otherwise ``(logits, ca_history)``
            where ``ca_history`` is an empty list.
        """
        # The value is not used below; the lookup is kept so that an empty
        # ``eeg_windows`` still fails here with IndexError.
        batch_size = eeg_windows[0].size(0)

        n_total = len(TIME_WINDOWS)
        window_features = []
        ca_history = []
        W_history = []
        window_logits_list = []
        ca_state = W_state = None

        stages = zip(self.window_encoders, self.ca_modules, eeg_windows)
        for encoder, ca_module, eeg_window in stages:
            encoded = encoder(eeg_window).squeeze(-1)
            gated, ca_state, W_state = ca_module(encoded, ca_state, W_state)
            window_features.append(gated)

            if not return_intermediates:
                continue

            ca_history.append(ca_state.detach())
            W_history.append(W_state.detach())

            # Windows not processed yet are stood in for by zeros so the
            # classifier always receives its full input width.
            n_missing = n_total - len(window_features)
            padded = window_features + [torch.zeros_like(gated)] * n_missing
            partial_logits = self.classifier(torch.cat(padded, dim=1))
            window_logits_list.append(partial_logits.detach())

        logits = self.classifier(torch.cat(window_features, dim=1))

        if return_intermediates:
            return logits, ca_history, W_history, window_logits_list
        return logits, ca_history
# === DATA LOADER ===
class FilteredTestDataset:
    """Test-split EEG samples restricted to images with only target categories.

    A sample is kept when its COCO image has at least one annotation in
    ``TARGET_IDS`` and none in ``EXCLUDED_IDS``.
    """

    def __init__(self, annotations_path, max_samples=1000):
        """Load the EEG test split and filter it against COCO annotations.

        Args:
            annotations_path: Path to a COCO annotations JSON file with an
                ``'annotations'`` list of ``image_id`` / ``category_id`` entries.
            max_samples: Upper bound on how many rows of the split are
                examined. Splits smaller than this are used in full.

        Raises:
            RuntimeError: If no sample passes the category filter.
        """
        print("Loading and filtering test dataset...")
        full_split = load_dataset("Alljoined/05_125", split='test', streaming=False)
        # Clamp so a split shorter than max_samples does not make select() fail.
        n_rows = min(max_samples, len(full_split))
        self.eeg_dataset = full_split.select(range(n_rows))

        with open(annotations_path, 'r', encoding='utf-8') as f:
            coco_data = json.load(f)

        # image_id -> set of category ids annotated on that image.
        image_annotations = defaultdict(set)
        for ann in coco_data['annotations']:
            image_annotations[ann['image_id']].add(ann['category_id'])

        self.filtered_samples = []
        for sample in self.eeg_dataset:
            ann_ids = image_annotations.get(sample['coco_id'], set())
            has_excluded = not ann_ids.isdisjoint(EXCLUDED_IDS)
            has_target = not ann_ids.isdisjoint(TARGET_IDS)
            if has_target and not has_excluded:
                self.filtered_samples.append({
                    'coco_id': sample['coco_id'],
                    'eeg_data': np.array(sample['EEG'], dtype=np.float32),
                })

        print(f"Loaded {len(self.filtered_samples)} filtered test samples.")
        if not self.filtered_samples:
            raise RuntimeError("No suitable test samples found.")

    def get_eeg_windows(self, sample_info):
        """Cut one sample into the configured time windows and z-score each.

        Args:
            sample_info: Dict with an ``'eeg_data'`` 2-D array. It is sliced
                along axis 1, so that axis is presumably time and axis 0
                channels -- confirm against the dataset card.

        Returns:
            List with one array per ``TIME_WINDOWS`` entry. Each array is
            normalised per row to zero mean and unit standard deviation.
            Windows that run past the end of the recording are extended by
            repeating the last available sample.

        Raises:
            ValueError: If the recording ends before a window starts, so
                there is nothing to extend from.
        """
        eeg_data = sample_info['eeg_data']
        n_available = eeg_data.shape[1]
        eeg_windows = []
        for start_ms, end_ms, label in TIME_WINDOWS:
            start_idx = int(start_ms / 1000 * EEG_SAMPLE_RATE)
            end_idx = int(end_ms / 1000 * EEG_SAMPLE_RATE)
            if n_available <= start_idx:
                raise ValueError(
                    f"EEG recording has {n_available} samples; window "
                    f"'{label}' starts at sample {start_idx}."
                )

            # Slicing clips at the array end, so a short recording simply
            # yields a shorter window that is padded below.
            window = eeg_data[:, start_idx:end_idx]
            n_missing = (end_idx - start_idx) - window.shape[1]
            if n_missing > 0:
                window = np.pad(window, ((0, 0), (0, n_missing)), 'edge')

            row_mean = window.mean(axis=1, keepdims=True)
            row_std = window.std(axis=1, keepdims=True)
            # The epsilon keeps constant rows from dividing by zero.
            eeg_windows.append((window - row_mean) / (row_std + 1e-8))
        return eeg_windows

    def get_random_sample_info(self):
        """Return one filtered sample dict chosen uniformly at random."""
        return random.choice(self.filtered_samples)
# === VIEWER APPLICATION ===
class CalciumBridgeViewer(tk.Tk):
    """Main window: loads a model, runs it on random test samples, plots results.

    The window has a control bar (paths, load/test buttons, status), a panel
    showing the COCO image for the current sample, and three tabs with
    matplotlib figures: per-window predictions, calcium state, and EEG
    heatmaps.
    """

    def __init__(self):
        """Create the window and build its widgets; no model is loaded yet."""
        super().__init__()
        self.title("Calcium-Bridge EEG Constraint Viewer V2 (Extended Window)")
        self.geometry("2000x1000")
        self.model, self.test_data = None, None
        # Defined up front so methods that read them never hit AttributeError
        # when the Browse buttons were not used.
        self.coco_path = ""
        self.annotations_path = ""
        # Holds the PhotoImage so Tk does not lose it to garbage collection.
        self.pil_image_tk = None
        self.setup_gui()

    def setup_gui(self):
        """Build the control bar, the image panel and the plot notebook."""
        control_frame = ttk.Frame(self)
        control_frame.pack(pady=10, padx=10, fill=tk.X)

        ttk.Label(control_frame, text="COCO Path:").pack(side=tk.LEFT, padx=5)
        self.coco_var = tk.StringVar()
        ttk.Entry(control_frame, textvariable=self.coco_var, width=20).pack(side=tk.LEFT, padx=2)
        ttk.Button(control_frame, text="Browse", command=self.browse_coco).pack(side=tk.LEFT, padx=5)

        ttk.Label(control_frame, text="Annotations:").pack(side=tk.LEFT, padx=5)
        self.ann_var = tk.StringVar()
        ttk.Entry(control_frame, textvariable=self.ann_var, width=20).pack(side=tk.LEFT, padx=2)
        ttk.Button(control_frame, text="Browse", command=self.browse_ann).pack(side=tk.LEFT, padx=5)

        ttk.Button(control_frame, text="Load V2 Model", command=self.load_model).pack(side=tk.LEFT, padx=20)
        self.test_btn = ttk.Button(
            control_frame, text="Test Random Sample", command=self.test_sample, state=tk.DISABLED
        )
        self.test_btn.pack(side=tk.LEFT, padx=5)
        self.status_label = tk.Label(control_frame, text="Model: Not loaded", fg="gray")
        self.status_label.pack(side=tk.LEFT, padx=20)

        main_paned = ttk.PanedWindow(self, orient=tk.HORIZONTAL)
        main_paned.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        image_frame = ttk.Frame(main_paned, width=400)
        main_paned.add(image_frame, weight=0)
        ttk.Label(image_frame, text="COCO Image", font=("Arial", 12, "bold")).pack(pady=5)
        self.image_canvas = tk.Canvas(image_frame, width=400, height=400, bg='lightgray')
        self.image_canvas.pack()
        self.coco_id_label = ttk.Label(image_frame, text="COCO ID: N/A")
        self.coco_id_label.pack(pady=5)

        self.notebook = ttk.Notebook(main_paned)
        main_paned.add(self.notebook, weight=1)
        self.create_tabs()

    def create_tabs(self):
        """Create the three plot tabs and keep their figure/canvas pairs."""
        self.constraint_fig, self.constraint_canvas = self.create_tab(
            "Constraint Satisfaction",
            "How predictions crystallize as constraints are satisfied",
        )
        self.calcium_fig, self.calcium_canvas = self.create_tab(
            "Calcium Attention",
            "Calcium state evolution: What the model 'focuses on' at each stage",
        )
        self.eeg_fig, self.eeg_canvas = self.create_tab(
            "EEG Heatmaps",
            "Raw EEG signals for each time window",
        )

    def create_tab(self, title, description):
        """Add a notebook tab holding a description label and an empty figure.

        Args:
            title: Text shown on the tab.
            description: Text shown above the figure.

        Returns:
            ``(fig, canvas)``: the matplotlib figure and its Tk canvas.
        """
        tab = ttk.Frame(self.notebook)
        self.notebook.add(tab, text=title)
        ttk.Label(tab, text=description, font=("Arial", 11)).pack(pady=5)
        fig = plt.Figure()
        canvas = FigureCanvasTkAgg(fig, tab)
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        return fig, canvas

    def browse_coco(self):
        """Ask for the COCO image root directory; a cancelled dialog changes nothing."""
        path = filedialog.askdirectory()
        if path:
            self.coco_var.set(path)
            self.coco_path = path

    def browse_ann(self):
        """Ask for the annotations JSON file; a cancelled dialog changes nothing."""
        path = filedialog.askopenfilename(filetypes=[("JSON", "*.json")])
        if path:
            self.ann_var.set(path)
            self.annotations_path = path

    def _sync_paths(self):
        """Copy the entry-field text into the path attributes.

        This makes paths typed by hand count, not only ones picked via Browse.
        """
        self.coco_path = self.coco_var.get().strip()
        self.annotations_path = self.ann_var.get().strip()

    def _report_error(self, summary, exc):
        """Show ``summary`` and ``exc`` in an error dialog, then print the traceback.

        Must be called from inside an ``except`` block so the traceback of the
        exception being handled is available.
        """
        messagebox.showerror("Error", f"{summary}:\n{exc}")
        print(traceback.format_exc())

    def load_model(self):
        """Pick a checkpoint, load it, and load the filtered test dataset.

        The new model and dataset replace the current ones only after both
        loaded successfully, so a failed attempt leaves any previously loaded
        model in place.
        """
        model_path = filedialog.askopenfilename(
            filetypes=[("PyTorch Model", "*.pth")],
            title="Select calcium_bridge_eeg_model_v2.pth",
        )
        if not model_path:
            return
        self._sync_paths()
        if not self.annotations_path:
            messagebox.showwarning("Annotations", "Select a COCO annotations file before loading the model.")
            return

        try:
            # NOTE(security): torch.load unpickles the file; only open
            # checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=DEVICE)
            model = TemporalConstraintEEGModel().to(DEVICE)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            test_data = FilteredTestDataset(self.annotations_path)
        except Exception as e:
            self._report_error("Failed to load model", e)
            return

        self.model, self.test_data = model, test_data
        self.status_label.config(text="Model: V2 Loaded ✓", fg="green")
        self.test_btn.config(state=tk.NORMAL)

    def _fetch_image(self, coco_id):
        """Return the RGB image for ``coco_id`` or ``None`` when it is not found.

        Looks for ``<coco_path>/<split>/<12-digit id>.jpg`` in the train2017,
        val2017 and test2017 sub-directories, in that order.
        """
        self._sync_paths()
        if not self.coco_path:
            return None
        formatted_id = f"{coco_id:012d}.jpg"
        for s in ["train2017", "val2017", "test2017"]:
            path = os.path.join(self.coco_path, s, formatted_id)
            if os.path.exists(path):
                # convert() returns an independent image, so the file handle
                # can be closed straight away.
                with Image.open(path) as img:
                    return img.convert("RGB")
        return None

    def test_sample(self):
        """Run the model on one random test sample and refresh every view."""
        if self.model is None or self.test_data is None:
            return
        try:
            sample_info = self.test_data.get_random_sample_info()
            coco_id = sample_info['coco_id']
            # Always refresh the image panel, even when the image is missing,
            # so it never shows a different sample than the plots.
            self.display_image(self._fetch_image(coco_id), coco_id)

            eeg_windows_np = self.test_data.get_eeg_windows(sample_info)
            eeg_windows = [torch.from_numpy(w).unsqueeze(0).to(DEVICE) for w in eeg_windows_np]
            with torch.no_grad():
                logits, ca_history, _, window_logits = self.model(eeg_windows, return_intermediates=True)

            self.visualize_constraint_satisfaction(window_logits, logits)
            self.visualize_calcium_evolution(ca_history)
            self.visualize_eeg_heatmaps(eeg_windows_np)
        except Exception as e:
            self._report_error("Failed to process sample", e)

    def display_image(self, image, coco_id):
        """Show ``image`` scaled to fit the 400x400 canvas and update the ID label.

        Args:
            image: PIL image, or ``None`` when no image file was found; the
                canvas is then left blank.
            coco_id: Identifier written to the label below the canvas.
        """
        # Remove the previous picture so images do not stack on the canvas.
        self.image_canvas.delete("all")
        if image is None:
            self.pil_image_tk = None
            self.coco_id_label.config(text=f"COCO ID: {coco_id} (image not found)")
            return

        ratio = min(400 / image.width, 400 / image.height)
        # max(1, ...) guards against a zero-sized side on extreme aspect ratios.
        new_size = (max(1, int(image.width * ratio)), max(1, int(image.height * ratio)))
        resized = image.resize(new_size, Image.LANCZOS)
        self.pil_image_tk = ImageTk.PhotoImage(resized)
        self.image_canvas.create_image(200, 200, image=self.pil_image_tk)
        self.coco_id_label.config(text=f"COCO ID: {coco_id}")

    @staticmethod
    def _window_title(index):
        """Return the two-line subplot title for ``TIME_WINDOWS[index]``."""
        start_ms, end_ms, label = TIME_WINDOWS[index]
        return f"{label}\n({start_ms}-{end_ms}ms)"

    @staticmethod
    def _draw_prediction_bars(ax, names, probs, color, title):
        """Draw one horizontal probability bar chart with the shared styling."""
        ax.barh(names, probs, color=color)
        ax.set_title(title, fontsize=10)
        ax.set_xlim(0, 1)
        ax.invert_yaxis()
        ax.tick_params(axis='y', labelsize=8)

    def visualize_constraint_satisfaction(self, window_logits, final_logits):
        """Plot sigmoid probabilities per window for the final top-10 classes.

        Args:
            window_logits: Intermediate logits, one tensor per time window.
            final_logits: Logits from the full model; they select which ten
                classes are shown in every panel.

        Class indices are mapped to names through the insertion order of
        ``TARGET_CATEGORIES`` -- presumably the label order used in training;
        confirm against the training code.
        """
        self.constraint_fig.clear()
        cat_list = list(TARGET_CATEGORIES.keys())
        final_probs = torch.sigmoid(final_logits).squeeze(0).cpu().numpy()
        top_indices = np.argsort(final_probs)[::-1][:10]
        top_names = [cat_list[idx] for idx in top_indices]

        # squeeze=False keeps a 2-D axes array whatever the column count.
        axes = self.constraint_fig.subplots(1, len(window_logits) + 1, squeeze=False)[0]
        for i, (ax, wl) in enumerate(zip(axes[:-1], window_logits)):
            probs = torch.sigmoid(wl).squeeze(0).cpu().numpy()[top_indices]
            self._draw_prediction_bars(ax, top_names, probs, 'steelblue', self._window_title(i))
        self._draw_prediction_bars(
            axes[-1], top_names, final_probs[top_indices], 'darkgreen', "Final\n(Combined)"
        )

        self.constraint_fig.suptitle(
            "Constraint Satisfaction: Predictions Crystallizing Over Time",
            fontsize=14, fontweight='bold',
        )
        self.constraint_fig.tight_layout()
        self.constraint_canvas.draw()

    def visualize_calcium_evolution(self, ca_history):
        """Plot each window's calcium state: full trace on top, top-20 bars below.

        Args:
            ca_history: Calcium state tensors, one per time window.
        """
        self.calcium_fig.clear()
        # squeeze=False so axes[row, col] indexing also works for one window.
        axes = self.calcium_fig.subplots(2, len(ca_history), squeeze=False)
        for i, ca_state in enumerate(ca_history):
            ca_np = ca_state.squeeze(0).cpu().numpy()
            top_20_idx = np.argsort(ca_np)[::-1][:20]

            trace_ax, bar_ax = axes[0, i], axes[1, i]
            trace_ax.plot(ca_np, 'r')
            trace_ax.fill_between(range(len(ca_np)), ca_np, color='r', alpha=0.3)
            trace_ax.set_title(self._window_title(i), fontsize=10)

            bar_ax.barh([f"F{idx}" for idx in top_20_idx], ca_np[top_20_idx], color='darkred')
            bar_ax.invert_yaxis()
            bar_ax.tick_params(axis='y', labelsize=7)

        self.calcium_fig.suptitle(
            "Calcium Attention: What the Model Focuses On", fontsize=14, fontweight='bold'
        )
        self.calcium_fig.tight_layout()
        self.calcium_canvas.draw()

    def visualize_eeg_heatmaps(self, eeg_windows_np):
        """Show every normalised EEG window as a heatmap with its own colorbar.

        Args:
            eeg_windows_np: 2-D arrays, one per time window. The colour scale
                is fixed to the range -3..3.
        """
        self.eeg_fig.clear()
        # squeeze=False so the axes stay iterable even for a single window.
        axes = self.eeg_fig.subplots(1, len(eeg_windows_np), squeeze=False)[0]
        for i, (ax, eeg_data) in enumerate(zip(axes, eeg_windows_np)):
            im = ax.imshow(eeg_data, aspect='auto', cmap='RdBu_r', vmin=-3, vmax=3)
            ax.set_title(self._window_title(i), fontsize=10)
            if i == 0:
                ax.set_ylabel("Channel")
            self.eeg_fig.colorbar(im, ax=ax)

        self.eeg_fig.suptitle("Raw EEG Signals by Time Window", fontsize=14, fontweight='bold')
        self.eeg_fig.tight_layout()
        self.eeg_canvas.draw()
if __name__ == "__main__":
    # Script entry point: build the viewer window, then hand control to the
    # Tk event loop until the window is closed.
    app = CalciumBridgeViewer()
    app.mainloop()