Spaces:
Runtime error
Runtime error
File size: 5,707 Bytes
297d76f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | """
spike_decoder.py — Decode temporal spike trains into feature vectors.
Takes the rich temporal spike patterns from the encoder and extracts
meaningful features that a downstream classifier can use.
Three decoding strategies, each capturing different temporal information:
1. RateDecoder — baseline, ignores timing (same as Part 1)
2. TemporalDecoder — uses first spike times (when neurons fire)
3. SyncDecoder — uses synchrony patterns (which neurons fire together)
TBC's insight: timing carries MORE information than rate alone.
We test this by comparing classifier accuracy across all three decoders.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class RateDecoder:
    """
    Control-condition decoder: collapses each spike train to a firing rate.

    Spike timing is discarded entirely; every neuron is summarised by the
    fraction of timesteps on which it spiked. This mirrors the Part 1
    reservoir readout and serves as the baseline the other decoders are
    compared against.
    """

    @staticmethod
    def _rate(spike_array: np.ndarray, time_axis: int) -> np.ndarray:
        """Average ``spike_array`` over the axis indexed by ``time_axis``."""
        return spike_array.mean(axis=time_axis)

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Args:
            spikes: (n_timesteps, n_neurons) binary array
        Returns:
            features: (n_neurons,) mean firing rate per neuron
        """
        return self._rate(spikes, time_axis=0)

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """Args: (N, T, n_neurons) → Returns: (N, n_neurons)"""
        return self._rate(spike_batch, time_axis=1)
class TemporalDecoder:
    """
    Temporal decoder — uses WHEN neurons fire, not just how often.

    Features extracted per neuron, concatenated in this order:
        1. First spike time, normalised by the train length
           (1.0 if the neuron never fires) — bright → early.
        2. Overall spike density (total spikes / timesteps).
        3. Spike density in the early window — captures onset response.
        4. Spike density in the late window — captures sustained response.

    The early/late split is placed at ``int(T * early_window)`` where ``T``
    is the length of the spike train actually being decoded, so trains that
    differ from the configured ``n_timesteps`` are still split at the
    configured fraction. Empty windows yield 0.0 instead of NaN.
    """

    def __init__(self, n_timesteps: int = 100, early_window: float = 0.3):
        self.n_timesteps = n_timesteps
        self.early_window = early_window
        # Split index for a train of the configured length (kept for callers
        # that read it; decoding recomputes it from the observed length).
        self.early_cutoff = int(n_timesteps * early_window)

    @staticmethod
    def _time_mean(window: np.ndarray) -> np.ndarray:
        """Mean over the time axis (-2); zeros instead of NaN for an empty window."""
        if window.shape[-2] == 0:
            return np.zeros(window.shape[:-2] + window.shape[-1:])
        return window.mean(axis=-2)

    def _features(self, spikes: np.ndarray) -> np.ndarray:
        """Compute features for an array shaped (..., n_timesteps, n_neurons)."""
        spikes = np.asarray(spikes)
        if spikes.ndim < 2:
            raise ValueError("spikes must have shape (..., n_timesteps, n_neurons)")
        n_timesteps = spikes.shape[-2]
        fired = spikes > 0

        # Feature 1: first spike time (normalised); 1.0 marks "never fired",
        # which no real spike can produce because t / n_timesteps < 1.
        if n_timesteps == 0:
            first_spike = np.ones(fired.shape[:-2] + fired.shape[-1:])
        else:
            # argmax returns the index of the first True along time; it is 0
            # for silent neurons too, so those are overridden with 1.0.
            first_idx = fired.argmax(axis=-2)
            first_spike = np.where(fired.any(axis=-2), first_idx / n_timesteps, 1.0)

        # Feature 2: overall spike density (total count / n_timesteps).
        spike_count = self._time_mean(spikes)

        # Features 3 & 4: early / late window densities, split at the
        # configured fraction of the observed train length.
        cutoff = int(n_timesteps * self.early_window)
        early_spikes = self._time_mean(spikes[..., :cutoff, :])
        late_spikes = self._time_mean(spikes[..., cutoff:, :])

        features = np.concatenate(
            [first_spike, spike_count, early_spikes, late_spikes],
            axis=-1,
        )
        return features.astype(np.float32)

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Args:
            spikes: (n_timesteps, n_neurons)
        Returns:
            features: (4 * n_neurons,) float32 temporal feature vector
        """
        return self._features(spikes)

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """Args: (N, T, n_neurons) → Returns: (N, 4*n_neurons) float32"""
        if isinstance(spike_batch, np.ndarray):
            # Vectorised over the whole batch.
            return self._features(spike_batch)
        # Non-array iterables may hold trains of different lengths; decode each.
        return np.array([self._features(s) for s in spike_batch])
class SyncDecoder:
    """
    Synchrony decoder — captures which neurons fire TOGETHER.

    Synchronous firing between neurons is a key biological signal.
    Two neurons firing at the same timestep = correlated activity = shared input.
    This captures spatial relationships that rate coding misses.

    Implementation: instead of all neuron pairs (too slow), neurons are split
    into ``n_groups`` contiguous groups covering every neuron. For each group
    the summed activity per timestep is computed, and its coefficient of
    variation (std / mean over time) is used as the synchrony score:
    synchronous firing within a group makes that sum more variable relative
    to its mean. Silent groups score 0.
    """

    def __init__(self, n_timesteps: int = 100, n_groups: int = 16):
        self.n_timesteps = n_timesteps  # kept for API compatibility; not used by decode
        self.n_groups = n_groups

    def _membership(self, n_neurons: int) -> np.ndarray:
        """One-hot (n_neurons, n_groups) matrix assigning neurons to contiguous groups."""
        # floor(i * n_groups / n_neurons) places every neuron in a group (no
        # leftover neurons are dropped); when n_neurons is divisible by
        # n_groups this equals fixed-size blocks [g*size, (g+1)*size).
        group_of = (np.arange(n_neurons) * self.n_groups) // max(n_neurons, 1)
        membership = np.zeros((n_neurons, self.n_groups))
        membership[np.arange(n_neurons), group_of] = 1.0
        return membership

    def _features(self, spikes: np.ndarray) -> np.ndarray:
        """Compute features for an array shaped (..., n_timesteps, n_neurons)."""
        spikes = np.asarray(spikes)
        if spikes.ndim < 2:
            raise ValueError("spikes must have shape (..., n_timesteps, n_neurons)")
        n_timesteps, n_neurons = spikes.shape[-2:]
        if n_timesteps == 0:
            # No observations: emit zeros rather than NaN from empty means.
            lead = spikes.shape[:-2]
            rate_features = np.zeros(lead + (n_neurons,))
            sync_features = np.zeros(lead + (self.n_groups,))
        else:
            # Base rate features
            rate_features = spikes.mean(axis=-2)
            # (..., T, n_groups): total spikes per group per timestep
            group_activity = spikes @ self._membership(n_neurons)
            # 1e-8 avoids 0/0 for silent groups (their score is then 0).
            sync_features = group_activity.std(axis=-2) / (
                group_activity.mean(axis=-2) + 1e-8
            )
        features = np.concatenate([rate_features, sync_features], axis=-1)
        return features.astype(np.float32)

    def decode(self, spikes: np.ndarray) -> np.ndarray:
        """
        Args:
            spikes: (n_timesteps, n_neurons)
        Returns:
            features: (n_neurons + n_groups,) float32
                rate features + synchrony features per group
        """
        return self._features(spikes)

    def decode_batch(self, spike_batch: np.ndarray) -> np.ndarray:
        """Args: (N, T, n_neurons) → Returns: (N, n_neurons + n_groups) float32"""
        if isinstance(spike_batch, np.ndarray):
            # Vectorised over the whole batch.
            return self._features(spike_batch)
        # Non-array iterables may hold trains of different lengths; decode each.
        return np.array([self._features(s) for s in spike_batch])
class SpikeClassifier(nn.Module):
    """
    Simple MLP classifier that works with any decoder's output.
    Kept identical across all decoders to isolate the effect of decoding strategy.

    Architecture: input_dim → hidden1 → hidden2 → n_classes, with ReLU after
    each hidden layer and dropout after the first. The defaults reproduce the
    fixed 256 / 128 / p=0.3 configuration, so existing callers and saved
    ``state_dict`` keys (fc1, fc2, classifier) are unchanged.
    """

    def __init__(
        self,
        input_dim: int,
        n_classes: int = 10,
        hidden1: int = 256,
        hidden2: int = 128,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.classifier = nn.Linear(hidden2, n_classes)
        self.dropout_p = dropout

    def forward(self, x):
        """
        Args:
            x: float tensor whose last dimension is ``input_dim``
        Returns:
            Raw class scores (no softmax applied); last dimension ``n_classes``.
        """
        x = F.relu(self.fc1(x))
        # Functional dropout keyed on self.training: active in train(), off in eval().
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = F.relu(self.fc2(x))
        return self.classifier(x)
|