neuroforge / spike_decoder.py
dlokesha
Part 3 results: sync coding 87.8% > rate 86.8% — synchrony patterns outperform rate coding
297d76f
"""
spike_decoder.py — Decode temporal spike trains into feature vectors.
Takes the rich temporal spike patterns from the encoder and extracts
meaningful features that a downstream classifier can use.
Three decoding strategies, each capturing different temporal information:
1. RateDecoder — baseline, ignores timing (same as Part 1)
2. TemporalDecoder — uses first spike times (when neurons fire)
3. SyncDecoder — uses synchrony patterns (which neurons fire together)
TBC's insight: timing carries MORE information than rate alone.
We prove this by comparing classifier accuracy across all three decoders.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class RateDecoder:
    """
    Rate-coding baseline: spike timing is discarded and only the average
    activity of each neuron is kept. Serves as the control condition that
    the timing-aware decoders are compared against.
    """

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Average a single spike train over its time axis.

        Args:
            spikes: (n_timesteps, n_neurons) binary array
        Returns:
            (n_neurons,) mean firing rate of each neuron
        """
        per_neuron_rate = spikes.mean(axis=0)  # collapse time → (n_neurons,)
        return per_neuron_rate

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """
        Average every spike train of a batch over its time axis.

        Args:
            spike_batch: (N, T, n_neurons) array of spike trains
        Returns:
            (N, n_neurons) mean firing rates, one row per sample
        """
        time_axis = 1  # axis 0 indexes samples, axis 1 indexes timesteps
        batch_rates = spike_batch.mean(axis=time_axis)
        return batch_rates
class TemporalDecoder:
    """
    Temporal decoder — uses WHEN neurons fire, not just how often.

    Features extracted per neuron, concatenated in this order:
        1. First spike time — index of the first spike divided by the number
           of timesteps; 1.0 when the neuron never fires
        2. Spike count — total spikes divided by the number of timesteps
        3. Early-window spike density — mean activity before ``early_cutoff``
        4. Late-window spike density — mean activity from ``early_cutoff`` on
    """

    def __init__(self, n_timesteps: int = 100, early_window: float = 0.3):
        """
        Args:
            n_timesteps: nominal spike-train length used to place the cutoff
            early_window: fraction of ``n_timesteps`` that counts as "early"
        """
        self.n_timesteps = n_timesteps
        # Absolute timestep index that separates the early and late windows.
        self.early_cutoff = int(n_timesteps * early_window)

    @staticmethod
    def _window_density(window: np.ndarray) -> np.ndarray:
        """Per-neuron mean of a (T_window, n_neurons) slice; zeros if it has no rows."""
        if window.shape[0] == 0:
            # mean() over zero rows would return NaN and emit a RuntimeWarning.
            return np.zeros(window.shape[1], dtype=np.float64)
        return window.mean(axis=0)

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Turn one spike train into its temporal feature vector.

        Args:
            spikes: (n_timesteps, n_neurons) array; values > 0 count as spikes
        Returns:
            features: (4 * n_neurons,) float32 vector laid out as
                [first_spike, spike_count, early_density, late_density]
        """
        spikes = np.asarray(spikes)
        n_timesteps, n_neurons = spikes.shape

        if n_timesteps == 0:
            # No observations: nothing fired and every count is zero.
            first_spike = np.ones(n_neurons)
            spike_count = np.zeros(n_neurons)
        else:
            fired = spikes > 0
            # argmax returns the first True along the time axis; it also
            # returns 0 for all-False columns, so those are masked to 1.0.
            first_index = fired.argmax(axis=0)
            first_spike = np.where(
                fired.any(axis=0), first_index / n_timesteps, 1.0
            )
            spike_count = spikes.sum(axis=0) / n_timesteps

        # Either window may be empty (cutoff at 0, or at/after the end of the
        # input); an empty window contributes zeros instead of NaN.
        early_spikes = self._window_density(spikes[:self.early_cutoff])
        late_spikes = self._window_density(spikes[self.early_cutoff:])

        features = np.concatenate([
            first_spike,
            spike_count,
            early_spikes,
            late_spikes,
        ])
        return features.astype(np.float32)

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """
        Decode every spike train in a batch.

        Args:
            spike_batch: (N, T, n_neurons) array of spike trains
        Returns:
            (N, 4 * n_neurons) float32 feature matrix
        """
        rows = [self.decode(s) for s in spike_batch]
        if not rows:
            # Keep the result two-dimensional for an empty batch.
            batch_shape = np.shape(spike_batch)
            n_neurons = batch_shape[-1] if len(batch_shape) == 3 else 0
            return np.zeros((0, 4 * n_neurons), dtype=np.float32)
        return np.array(rows)
class SyncDecoder:
    """
    Synchrony decoder — captures which neurons fire TOGETHER.

    Neurons are partitioned into ``n_groups`` contiguous groups. For each
    group the population activity (spikes summed over the group's neurons)
    is tracked over time, and its coefficient of variation (std / mean) is
    used as a cheap stand-in for pairwise coincidence: the value is high
    when the group's spikes bunch into the same timesteps.

    The output is the per-neuron firing rate followed by one synchrony value
    per group.
    """

    def __init__(self, n_timesteps: int = 100, n_groups: int = 16):
        """
        Args:
            n_timesteps: nominal spike-train length (stored, not used by decode)
            n_groups: number of contiguous neuron groups; must be >= 1
        Raises:
            ValueError: if ``n_groups`` is smaller than 1
        """
        if n_groups < 1:
            raise ValueError(f"n_groups must be >= 1, got {n_groups}")
        self.n_timesteps = n_timesteps
        self.n_groups = n_groups

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Turn one spike train into rate + synchrony features.

        Args:
            spikes: (n_timesteps, n_neurons)
        Returns:
            features: (n_neurons + n_groups,) float32 vector —
                rate features followed by one synchrony feature per group
        """
        spikes = np.asarray(spikes)
        n_timesteps, n_neurons = spikes.shape

        if n_timesteps == 0:
            # No observations: mean()/std() would be NaN, so report zeros.
            return np.zeros(n_neurons + self.n_groups, dtype=np.float32)

        # Base rate features
        rate_features = spikes.mean(axis=0)

        # array_split spreads the remainder over the leading groups, so every
        # neuron belongs to a group even when n_neurons is not a multiple of
        # n_groups. When it divides evenly the groups are the same contiguous
        # equal-sized slices as n_neurons // n_groups would give.
        groups = np.array_split(spikes, self.n_groups, axis=1)

        sync_features = np.zeros(self.n_groups)
        for g, group_spikes in enumerate(groups):
            # (T,) total spikes in this group at each timestep
            group_activity = group_spikes.sum(axis=1)
            # Coefficient of variation; the epsilon keeps silent groups at 0.
            sync_features[g] = group_activity.std() / (group_activity.mean() + 1e-8)

        features = np.concatenate([rate_features, sync_features])
        return features.astype(np.float32)

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """
        Decode every spike train in a batch.

        Args:
            spike_batch: (N, T, n_neurons) array of spike trains
        Returns:
            (N, n_neurons + n_groups) float32 feature matrix
        """
        rows = [self.decode(s) for s in spike_batch]
        if not rows:
            # Keep the result two-dimensional for an empty batch.
            batch_shape = np.shape(spike_batch)
            n_neurons = batch_shape[-1] if len(batch_shape) == 3 else 0
            return np.zeros((0, n_neurons + self.n_groups), dtype=np.float32)
        return np.array(rows)
class SpikeClassifier(nn.Module):
    """
    Simple MLP classifier that works with any decoder's output.

    Architecture: input → fc1 → ReLU → dropout → fc2 → ReLU → classifier.
    Use the same settings for every decoder so that accuracy differences
    come from the decoding strategy and not from the classifier.
    """

    def __init__(
        self,
        input_dim: int,
        n_classes: int = 10,
        hidden1: int = 256,
        hidden2: int = 128,
        dropout: float = 0.3,
    ):
        """
        Args:
            input_dim: length of the decoder feature vector
            n_classes: number of output logits
            hidden1: width of the first hidden layer
            hidden2: width of the second hidden layer
            dropout: drop probability applied after the first hidden layer
                (only active in training mode)
        Raises:
            ValueError: if ``dropout`` is outside [0, 1]
        """
        super().__init__()
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")
        self.dropout_p = dropout
        # Layer names and creation order are part of the checkpoint format
        # (state-dict keys) and of seeded initialisation; keep them stable.
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.classifier = nn.Linear(hidden2, n_classes)

    def forward(self, x):
        """
        Args:
            x: (..., input_dim) feature tensor
        Returns:
            (..., n_classes) unnormalised class logits
        """
        x = F.relu(self.fc1(x))
        # Functional dropout must be told the mode explicitly.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = F.relu(self.fc2(x))
        return self.classifier(x)