neuroforge / spike_encoder.py
dlokesha
Part 3 results: sync coding 87.8% > rate 86.8% — synchrony patterns outperform rate coding
297d76f
"""
spike_encoder.py — Convert images into temporal spike trains.
TBC Post 1: "Our results rely on simple spike-rate summaries which do not
capture the temporal structure known to be central to biological computation.
Ongoing work focuses on decoding spike timing and using time as an
explicit encoding dimension."
This is exactly what Part 3 builds.
Rate coding (Part 1): image → fraction of time each unit was active
Temporal coding (Part 3): image → WHEN each neuron fires (precise timing)
Why timing matters:
Real neurons communicate through spike timing, not just rate.
A neuron firing at t=2ms vs t=8ms carries different information
even if both fire once. The brain uses this — we replicate it here.
Poisson spike trains:
The standard model for biological spike generation.
Each pixel intensity becomes a firing RATE (spikes/second).
We then sample actual spike times from a Poisson process.
High intensity pixel → high firing rate → spikes come early and often
Low intensity pixel → low firing rate → spikes are rare and late
"""
import numpy as np
import torch
class PoissonEncoder:
    """Encode a static image as stochastic, rate-driven spike trains.

    Each pixel drives one neuron whose firing rate is proportional to the
    pixel intensity. Time is split into ``n_timesteps`` bins of width ``dt``;
    in every bin each neuron fires independently with probability
    ``rate * dt`` (the per-bin Bernoulli approximation of a Poisson process).

    Output shape: (n_timesteps, n_neurons) - binary matrix,
    1 = spike, 0 = silence.
    """

    def __init__(
        self,
        n_timesteps: int = 100,   # number of simulated bins (= ms at default dt)
        max_rate: float = 100.0,  # firing rate (Hz) of the brightest pixel
        dt: float = 0.001,        # bin width in seconds (1 ms)
        seed: int = 42,
    ):
        self.n_timesteps = n_timesteps
        self.max_rate = max_rate
        self.dt = dt
        # The legacy RandomState generator is kept on purpose: switching to
        # another generator would change the spike trains produced by
        # existing seeds.
        self.rng = np.random.RandomState(seed)

    def encode(self, image: np.ndarray) -> np.ndarray:
        """Convert an image into a spike train matrix.

        Args:
            image: grayscale image (array or nested sequence), values
                expected in [0, 1]. It is flattened, so every pixel becomes
                one neuron.

        Returns:
            spikes: (n_timesteps, n_pixels) float32 array of 0/1 values;
                spikes[t, i] == 1 means neuron i fired at timestep t.
        """
        # np.asarray lets plain nested lists through as well as ndarrays.
        pixels = np.asarray(image).ravel()  # (n_neurons,)
        n_neurons = pixels.size

        # Pixel intensity -> firing rate: 1.0 -> max_rate Hz, 0.0 -> 0 Hz.
        rates = pixels * self.max_rate

        # P(spike in one bin) = rate * dt,
        # e.g. 100 Hz * 0.001 s = 0.1 (10% chance per ms).
        # Clipping keeps out-of-range intensities from producing
        # probabilities outside [0, 1].
        spike_probs = np.clip(rates * self.dt, 0, 1)

        # One uniform draw per (timestep, neuron); the neuron fires when the
        # draw falls below its per-bin probability.
        draws = self.rng.rand(self.n_timesteps, n_neurons)
        return (draws < spike_probs).astype(np.float32)

    def encode_batch(self, images: np.ndarray) -> np.ndarray:
        """Encode a batch of images. Returns (N, T, n_neurons)."""
        # Images are encoded one after another, so the result depends on the
        # generator state left behind by earlier calls.
        return np.array([self.encode(img) for img in images])

    def get_first_spike_times(self, spikes: np.ndarray) -> np.ndarray:
        """Extract the normalised first-spike time of every neuron.

        Args:
            spikes: (n_timesteps, n_neurons) array; values > 0 are spikes.

        Returns:
            first_spike_times: (n_neurons,) float array in [0, 1]: index of
                the first spiking timestep divided by n_timesteps, or 1.0
                for a neuron that never fired.
        """
        spikes = np.asarray(spikes)
        n_timesteps, n_neurons = spikes.shape
        if n_timesteps == 0:
            # Nothing was simulated, so no neuron fired. Returning 1.0 here
            # avoids the 0/0 division the normalisation would perform.
            return np.ones(n_neurons, dtype=float)

        fired = spikes > 0
        # argmax over the time axis gives the first True per column; it also
        # yields 0 for all-False columns, so those are overridden with the
        # "never fired" value n_timesteps.
        first_idx = fired.argmax(axis=0)
        first_times = np.where(fired.any(axis=0), first_idx, n_timesteps)

        # Normalise to [0, 1].
        return first_times / n_timesteps
class TemporalEncoder:
    """Deterministic temporal encoder using rank-order coding.

    Rank-order coding: neurons fire in order of their intensity.
    Brightest pixel fires first, darkest fires last (or not at all).
    This is extremely efficient - the identity of the image is conveyed in
    the first few spikes.
    Discovered by Thorpe et al. (1996) - the brain uses this for rapid
    visual processing (we recognize faces in ~150ms despite slow neurons).
    """

    def __init__(self, n_timesteps: int = 100, min_intensity: float = 0.05):
        """
        Args:
            n_timesteps: number of time bins in the output spike matrix.
            min_intensity: pixels must be strictly brighter than this to
                emit a spike; dimmer pixels stay silent.
        """
        self.n_timesteps = n_timesteps
        self.min_intensity = min_intensity

    def encode(self, image: np.ndarray) -> np.ndarray:
        """Encode an image using rank-order coding.

        Brightest pixel -> fires at t=0.
        Dimmest pixel   -> fires in the last bin (or never, if it is not
        brighter than ``min_intensity``).
        Pixels with equal intensity are ranked by their flattened index.

        Args:
            image: grayscale image (array or nested sequence); flattened so
                every pixel becomes one neuron.

        Returns:
            spikes: (n_timesteps, n_pixels) float32 array with at most one
                1.0 per column.
        """
        pixels = np.asarray(image).ravel()
        if not np.issubdtype(pixels.dtype, np.floating):
            # Negating unsigned integers wraps around and negating booleans
            # raises, so rank on a float copy instead.
            pixels = pixels.astype(np.float64)
        n_neurons = pixels.size

        spikes = np.zeros((self.n_timesteps, n_neurons), dtype=np.float32)
        if n_neurons == 0 or self.n_timesteps == 0:
            return spikes  # nothing to rank / no bin to place a spike in

        # Rank pixels by intensity (brightest = rank 0). A stable sort makes
        # the order of equal pixels well defined; inverting the permutation
        # turns "sorted position -> pixel" into "pixel -> rank".
        order = np.argsort(-pixels, kind="stable")
        ranks = np.empty(n_neurons, dtype=np.intp)
        ranks[order] = np.arange(n_neurons)

        # Map rank r to bin floor(r * T / N) with integer arithmetic. The
        # float form (r / N * T) can land just below an integer (e.g.
        # 29 / 100 * 100 == 28.999...) and truncate into the previous bin.
        # Since r <= N - 1, the result never exceeds T - 1.
        spike_times = ranks * self.n_timesteps // n_neurons

        # Only sufficiently bright pixels emit their spike.
        active = np.flatnonzero(pixels > self.min_intensity)
        spikes[spike_times[active], active] = 1.0
        return spikes  # (n_timesteps, n_neurons)

    def encode_batch(self, images: np.ndarray) -> np.ndarray:
        """Encode a batch of images. Returns (N, n_timesteps, n_neurons)."""
        return np.array([self.encode(img) for img in images])