"""
Spike-Timing Dependent Plasticity (STDP) learning rule implementation.
This module implements biologically plausible STDP learning in pure PyTorch and
is designed to complement spiking neural network components such as those in the
Norse library. STDP adjusts synaptic weights based on the relative timing of
pre- and post-synaptic spikes:
- Pre-before-post (Δt > 0): Strengthens connection (LTP - Long-Term Potentiation)
- Post-before-pre (Δt < 0): Weakens connection (LTD - Long-Term Depression)
Reference: Bi & Poo (1998) - Synaptic Modifications in Cultured Hippocampal Neurons
"""
import torch
import torch.nn as nn
class STDPLearner(nn.Module):
"""
Spike-Timing Dependent Plasticity learning mechanism.
Implements the exponential STDP learning window from Bi & Poo (1998):
        ΔW(Δt) =  A+ * exp(-Δt/τ+)   if Δt > 0  (pre before post - LTP)
        ΔW(Δt) = -A- * exp( Δt/τ-)   if Δt < 0  (post before pre - LTD)
Where:
- Δt: Time difference between pre and post spikes (ms)
- A+, A-: Maximum weight change amplitudes
- τ+, τ-: Time constants for LTP and LTD windows
Args:
tau_plus: Time constant for LTP window (ms), default 20.0
tau_minus: Time constant for LTD window (ms), default 20.0
a_plus: LTP amplitude, default 0.005
a_minus: LTD amplitude, default 0.00525 (slightly asymmetric per Bi & Poo)
w_min: Minimum synaptic weight, default 0.0
w_max: Maximum synaptic weight, default 1.0
device: Device to use ('cuda' or 'cpu'), default 'cuda'
Reference: Bi & Poo (1998) Figure 1 - Exponential STDP learning window
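    Example (illustrative; uses only the API defined in this class):
        >>> learner = STDPLearner(device="cpu")
        >>> delta_t, dw = learner.get_learning_window()
        >>> # dw > 0 for delta_t > 0 (LTP), dw < 0 for delta_t < 0 (LTD)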
"""
def __init__(
self,
tau_plus: float = 20.0,
tau_minus: float = 20.0,
a_plus: float = 0.005,
a_minus: float = 0.00525,
w_min: float = 0.0,
w_max: float = 1.0,
device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
# STDP time constants (register as buffers so they move with model)
self.register_buffer("tau_plus", torch.tensor(tau_plus, dtype=torch.float32))
self.register_buffer("tau_minus", torch.tensor(tau_minus, dtype=torch.float32))
# STDP amplitudes
self.register_buffer("a_plus", torch.tensor(a_plus, dtype=torch.float32))
self.register_buffer("a_minus", torch.tensor(a_minus, dtype=torch.float32))
        # Weight bounds
        self.register_buffer("w_min", torch.tensor(w_min, dtype=torch.float32))
        self.register_buffer("w_max", torch.tensor(w_max, dtype=torch.float32))
        # Move buffers onto the requested device immediately; register_buffer
        # creates them on CPU, which would otherwise mismatch CUDA inputs
        self.to(self.device)
def compute_weight_change(
self, delta_t: torch.Tensor
) -> torch.Tensor:
"""
Compute weight change based on spike timing difference.
Uses the Bi & Poo (1998) exponential STDP window:
- Positive Δt (pre before post): LTP (strengthening)
- Negative Δt (post before pre): LTD (weakening)
Args:
delta_t: Time difference between pre and post spikes (ms)
Shape: (batch, num_synapses) or (num_synapses,)
Positive values = pre before post
Negative values = post before pre
Returns:
dw: Weight change for each synapse
Shape: same as delta_t
Positive values = strengthen, negative = weaken
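        Example (illustrative; expected values are approximate):
            >>> learner = STDPLearner(device="cpu")
            >>> dw = learner.compute_weight_change(torch.tensor([10.0, -10.0]))
            >>> # dw ≈ [0.0030, -0.0032]: LTP for Δt = +10 ms, LTD for Δt = -10 ms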
"""
# Cast buffers to tensors for type safety
a_plus: torch.Tensor = self.a_plus # type: ignore[assignment]
a_minus: torch.Tensor = self.a_minus # type: ignore[assignment]
tau_plus: torch.Tensor = self.tau_plus # type: ignore[assignment]
tau_minus: torch.Tensor = self.tau_minus # type: ignore[assignment]
# LTP: pre before post (Δt > 0)
ltp_mask = delta_t > 0
ltp_change = a_plus * torch.exp(-delta_t / tau_plus)
# LTD: post before pre (Δt < 0)
ltd_mask = delta_t < 0
ltd_change = -a_minus * torch.exp(delta_t / tau_minus)
        # Combine LTP and LTD (zeros inherit delta_t's device to avoid mismatches)
        dw = torch.zeros_like(delta_t)
dw = torch.where(ltp_mask, ltp_change, dw)
dw = torch.where(ltd_mask, ltd_change, dw)
return dw
def apply_stdp(
self,
weights: torch.Tensor,
pre_spike_times: torch.Tensor,
post_spike_times: torch.Tensor,
) -> torch.Tensor:
"""
Apply STDP weight updates based on spike timing.
Args:
weights: Current synaptic weights
Shape: (num_pre, num_post) or (batch, num_pre, num_post)
pre_spike_times: Times of pre-synaptic spikes (ms)
Shape: (num_pre,) or (batch, num_pre)
Use -inf for neurons that didn't spike
post_spike_times: Times of post-synaptic spikes (ms)
Shape: (num_post,) or (batch, num_post)
Use -inf for neurons that didn't spike
Returns:
updated_weights: Weights after STDP update
Shape: same as input weights
Clamped to [w_min, w_max]
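        Example (illustrative; -inf marks a neuron that never spiked):
            >>> learner = STDPLearner(device="cpu")
            >>> w = torch.full((2, 1), 0.5)
            >>> pre_t = torch.tensor([5.0, float("-inf")])
            >>> w_new = learner.apply_stdp(w, pre_t, torch.tensor([10.0]))
            >>> # w_new[0, 0] > 0.5 (LTP, Δt = +5 ms); w_new[1, 0] == 0.5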
"""
# Compute all pairwise spike time differences
# Δt = t_post - t_pre (positive if pre before post)
if pre_spike_times.dim() == 1:
# Single batch: (num_pre,) x (num_post,) -> (num_pre, num_post)
delta_t = post_spike_times.unsqueeze(0) - pre_spike_times.unsqueeze(1)
else:
# Batched: (batch, num_pre) x (batch, num_post) -> (batch, num_pre, num_post)
delta_t = post_spike_times.unsqueeze(1) - pre_spike_times.unsqueeze(2)
# Mask out pairs where either neuron didn't spike (spike_time = -inf)
valid_pairs = torch.isfinite(delta_t)
# Compute weight changes
dw = self.compute_weight_change(delta_t)
# Zero out invalid pairs
dw = torch.where(valid_pairs, dw, torch.zeros_like(dw))
# Update weights
updated_weights = weights + dw
# Clamp to bounds
w_min: torch.Tensor = self.w_min # type: ignore[assignment]
w_max: torch.Tensor = self.w_max # type: ignore[assignment]
updated_weights = torch.clamp(updated_weights, w_min, w_max)
return updated_weights
def forward(
self,
weights: torch.Tensor,
pre_spikes: torch.Tensor,
post_spikes: torch.Tensor,
time_window: int = 50,
) -> torch.Tensor:
"""
Forward pass: Apply STDP learning to weights based on spike trains.
        This method processes binary spike trains, extracts spike timings, and
        applies the STDP learning rule. Each timestep is treated as 1 ms, so
        step indices are used directly as spike times in milliseconds.
Args:
weights: Current synaptic weights (num_pre, num_post)
pre_spikes: Binary spike train for pre-synaptic neurons
Shape: (time_steps, num_pre)
1 = spike, 0 = no spike
post_spikes: Binary spike train for post-synaptic neurons
Shape: (time_steps, num_post)
1 = spike, 0 = no spike
time_window: Maximum time window for STDP (ms), default 50
Returns:
updated_weights: Weights after STDP update (num_pre, num_post)
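        Example (illustrative; pre fires at t=3, post at t=7, so Δt = +4 ms):
            >>> learner = STDPLearner(device="cpu")
            >>> pre = torch.zeros(10, 1); pre[3, 0] = 1.0
            >>> post = torch.zeros(10, 1); post[7, 0] = 1.0
            >>> learner(torch.full((1, 1), 0.5), pre, post)  # weight increases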
"""
time_steps, num_pre = pre_spikes.shape
num_post = post_spikes.shape[1]
        # Move spike trains onto the learner's device so masks and buffers align
        pre_spikes = pre_spikes.to(self.device)
        post_spikes = post_spikes.to(self.device)
        # Extract spike times; -inf marks neurons that never spiked
        pre_spike_times = torch.full(
            (num_pre,), float("-inf"), device=self.device, dtype=torch.float32
        )
        post_spike_times = torch.full(
            (num_post,), float("-inf"), device=self.device, dtype=torch.float32
        )
        # Keep only the most recent spike per neuron within the time window.
        # This is a nearest-spike approximation: earlier spikes in the window
        # are overwritten rather than paired all-to-all.
        for t in range(max(0, time_steps - time_window), time_steps):
            pre_spike_times[pre_spikes[t] > 0] = float(t)
            post_spike_times[post_spikes[t] > 0] = float(t)
# Apply STDP
updated_weights = self.apply_stdp(
weights, pre_spike_times, post_spike_times
)
return updated_weights
def get_learning_window(
self, time_range: tuple[float, float] = (-50.0, 50.0), num_points: int = 1000
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Generate the STDP learning window curve for visualization.
Useful for validating that the implementation matches Bi & Poo (1998) Figure 1.
Args:
time_range: (min_time, max_time) in ms, default (-50, 50)
num_points: Number of points to sample, default 1000
Returns:
delta_t: Time differences (ms), shape (num_points,)
dw: Corresponding weight changes, shape (num_points,)
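        Example (illustrative; assumes matplotlib is installed, which this
        module does not itself require):
            >>> import matplotlib.pyplot as plt
            >>> delta_t, dw = STDPLearner(device="cpu").get_learning_window()
            >>> plt.plot(delta_t.cpu(), dw.cpu())
            >>> plt.xlabel("Δt (ms)")
            >>> plt.ylabel("ΔW")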
"""
delta_t = torch.linspace(
time_range[0], time_range[1], num_points, device=self.device
)
dw = self.compute_weight_change(delta_t)
return delta_t, dw
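# Minimal smoke test (illustrative sketch; random spike trains, CPU only).
if __name__ == "__main__":
    learner = STDPLearner(device="cpu")
    weights = torch.full((4, 2), 0.5)
    pre = (torch.rand(100, 4) < 0.1).float()   # ~10% firing probability per step
    post = (torch.rand(100, 2) < 0.1).float()
    updated = learner(weights, pre, post)
    print("weight changes:", (updated - weights).flatten())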