File size: 9,489 Bytes
"""
Spike-Timing Dependent Plasticity (STDP) learning rule implementation.
This module implements biologically-plausible STDP learning using Norse library's
spiking neural network components. STDP adjusts synaptic weights based on the
relative timing of pre- and post-synaptic spikes:
- Pre-before-post (Δt > 0): Strengthens connection (LTP - Long-Term Potentiation)
- Post-before-pre (Δt < 0): Weakens connection (LTD - Long-Term Depression)
Reference: Bi & Poo (1998) - Synaptic Modifications in Cultured Hippocampal Neurons
"""
import torch
import torch.nn as nn
from typing import Optional
import math
class STDPLearner(nn.Module):
    """
    Spike-Timing Dependent Plasticity learning mechanism.

    Implements the exponential STDP learning window from Bi & Poo (1998):

        ΔW(Δt) =  A+ * exp(-Δt/τ+)   if Δt > 0 (pre before post - LTP)
        ΔW(Δt) = -A- * exp( Δt/τ-)   if Δt < 0 (post before pre - LTD)
        ΔW(Δt) =  0                  if Δt == 0

    Where:
        - Δt: Time difference between pre and post spikes (t_post - t_pre)
        - A+, A-: Maximum weight change amplitudes (both given as positive
          numbers; the LTD branch applies the negative sign)
        - τ+, τ-: Time constants for LTP and LTD windows

    All hyper-parameters are stored as float32 buffers so they follow the
    module through ``.to()`` / ``state_dict()``. Computation follows the
    device of the tensors passed in, so a learner constructed with the
    default ``device="cuda"`` still works on CPU tensors.

    Args:
        tau_plus: Time constant for LTP window (ms), default 20.0
        tau_minus: Time constant for LTD window (ms), default 20.0
        a_plus: LTP amplitude, default 0.005
        a_minus: LTD amplitude, default 0.00525 (slightly asymmetric per Bi & Poo)
        w_min: Minimum synaptic weight, default 0.0
        w_max: Maximum synaptic weight, default 1.0
        device: Preferred device ('cuda' or 'cpu'), default 'cuda'. Stored
            as ``self.device``; only ``get_learning_window`` allocates on it.

    Reference: Bi & Poo (1998) Figure 1 - Exponential STDP learning window
    """

    def __init__(
        self,
        tau_plus: float = 20.0,
        tau_minus: float = 20.0,
        a_plus: float = 0.005,
        a_minus: float = 0.00525,
        w_min: float = 0.0,
        w_max: float = 1.0,
        device: str = "cuda",
    ):
        super().__init__()
        # torch.device() only parses the string; it does not touch CUDA, so
        # construction succeeds even on machines without a GPU.
        self.device = torch.device(device)

        # STDP time constants (register as buffers so they move with model)
        self.register_buffer("tau_plus", torch.tensor(tau_plus, dtype=torch.float32))
        self.register_buffer("tau_minus", torch.tensor(tau_minus, dtype=torch.float32))

        # STDP amplitudes
        self.register_buffer("a_plus", torch.tensor(a_plus, dtype=torch.float32))
        self.register_buffer("a_minus", torch.tensor(a_minus, dtype=torch.float32))

        # Weight bounds
        self.register_buffer("w_min", torch.tensor(w_min, dtype=torch.float32))
        self.register_buffer("w_max", torch.tensor(w_max, dtype=torch.float32))

    def compute_weight_change(
        self, delta_t: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute weight change based on spike timing difference.

        Uses the exponential STDP window:
            - Positive Δt (pre before post): LTP (strengthening)
            - Negative Δt (post before pre): LTD (weakening)
            - Δt == 0 or NaN: no change

        Args:
            delta_t: Time difference between pre and post spikes,
                t_post - t_pre. Any shape; evaluated element-wise.

        Returns:
            dw: Weight change for each element of ``delta_t`` (same shape,
                same device). Positive = strengthen, negative = weaken.
        """
        # Cast buffers to tensors for type safety
        a_plus: torch.Tensor = self.a_plus  # type: ignore[assignment]
        a_minus: torch.Tensor = self.a_minus  # type: ignore[assignment]
        tau_plus: torch.Tensor = self.tau_plus  # type: ignore[assignment]
        tau_minus: torch.Tensor = self.tau_minus  # type: ignore[assignment]

        # Both branches are evaluated everywhere and selected by mask below;
        # overflow in the unselected branch is discarded by torch.where.
        ltp_change = a_plus * torch.exp(-delta_t / tau_plus)
        ltd_change = -a_minus * torch.exp(delta_t / tau_minus)

        # Allocate the zero baseline next to the computed values rather than
        # on self.device, so the input's device is always respected.
        dw = torch.zeros_like(ltp_change)
        dw = torch.where(delta_t > 0, ltp_change, dw)
        dw = torch.where(delta_t < 0, ltd_change, dw)
        return dw

    def apply_stdp(
        self,
        weights: torch.Tensor,
        pre_spike_times: torch.Tensor,
        post_spike_times: torch.Tensor,
        dt: float = 1.0,
    ) -> torch.Tensor:
        """
        Apply STDP weight updates based on spike timing.

        Every (pre, post) pair contributes one update computed from
        Δt = t_post - t_pre. Pairs whose Δt is not finite (either neuron
        marked with -inf, i.e. it did not spike) leave the weight unchanged.

        Args:
            weights: Current synaptic weights
                Shape: (num_pre, num_post) or (batch, num_pre, num_post)
            pre_spike_times: Times of pre-synaptic spikes
                Shape: (num_pre,) or (batch, num_pre)
                Use -inf for neurons that didn't spike
            post_spike_times: Times of post-synaptic spikes
                Shape: (num_post,) or (batch, num_post)
                Use -inf for neurons that didn't spike
            dt: Timestep resolution (ms), default 1.0. Currently unused;
                retained for interface compatibility.

        Returns:
            updated_weights: Weights after STDP update, clamped to
                [w_min, w_max]. Shape: same as input weights.
        """
        # Compute all pairwise spike time differences
        # Δt = t_post - t_pre (positive if pre before post)
        if pre_spike_times.dim() == 1:
            # Single batch: (num_pre,) x (num_post,) -> (num_pre, num_post)
            delta_t = post_spike_times.unsqueeze(0) - pre_spike_times.unsqueeze(1)
        else:
            # Batched: (batch, num_pre) x (batch, num_post) -> (batch, num_pre, num_post)
            delta_t = post_spike_times.unsqueeze(1) - pre_spike_times.unsqueeze(2)

        # -inf markers produce ±inf or NaN differences; treat those as "no pair".
        valid_pairs = torch.isfinite(delta_t)

        dw = self.compute_weight_change(delta_t)
        dw = torch.where(valid_pairs, dw, torch.zeros_like(dw))

        updated_weights = weights + dw

        # Clamp to bounds. The bounds are moved to the weights' device first
        # (a no-op when they already match) so the clamp never mixes devices.
        w_min: torch.Tensor = self.w_min  # type: ignore[assignment]
        w_max: torch.Tensor = self.w_max  # type: ignore[assignment]
        updated_weights = torch.clamp(
            updated_weights,
            w_min.to(updated_weights.device),
            w_max.to(updated_weights.device),
        )
        return updated_weights

    @staticmethod
    def _last_spike_times(spikes: torch.Tensor, start: int, stop: int) -> torch.Tensor:
        """Return each neuron's latest spike step within [start, stop), or -inf."""
        never = torch.full(
            (spikes.shape[1],), float("-inf"), device=spikes.device, dtype=torch.float32
        )
        if start >= stop:
            # Empty window: nobody is considered to have spiked.
            return never
        if spikes.shape[0] < stop:
            raise IndexError(
                f"spike train has {spikes.shape[0]} time steps, expected at least {stop}"
            )

        fired = spikes[start:stop] > 0  # (window, num_neurons)
        steps = torch.arange(
            start, stop, device=spikes.device, dtype=torch.float32
        ).unsqueeze(1)  # (window, 1), broadcast across neurons
        # Non-spiking entries become -inf, so the max over time is the most
        # recent spike step, or -inf when the neuron was silent.
        return torch.where(fired, steps, never).amax(dim=0)

    def forward(
        self,
        weights: torch.Tensor,
        pre_spikes: torch.Tensor,
        post_spikes: torch.Tensor,
        time_window: int = 50,
    ) -> torch.Tensor:
        """
        Forward pass: Apply STDP learning to weights based on spike trains.

        For every neuron, the most recent spike inside the trailing window is
        taken as its spike time (expressed as a time-step index), and the
        resulting pre/post times are passed to ``apply_stdp``.

        Args:
            weights: Current synaptic weights (num_pre, num_post)
            pre_spikes: Binary spike train for pre-synaptic neurons
                Shape: (time_steps, num_pre); values > 0 count as a spike
            post_spikes: Binary spike train for post-synaptic neurons
                Shape: (time_steps, num_post); values > 0 count as a spike
            time_window: Number of most recent time steps searched for
                spikes, default 50 (equals milliseconds only when one step
                is 1 ms). A value <= 0 searches nothing.

        Returns:
            updated_weights: Weights after STDP update (num_pre, num_post)

        Raises:
            IndexError: If ``post_spikes`` has fewer time steps than
                ``pre_spikes`` and the window is non-empty.
        """
        time_steps, _ = pre_spikes.shape
        window_start = max(0, time_steps - time_window)

        # Spike times are built on the spike trains' own device, not on
        # self.device, so CPU inputs work regardless of the configured device.
        pre_spike_times = self._last_spike_times(pre_spikes, window_start, time_steps)
        post_spike_times = self._last_spike_times(post_spikes, window_start, time_steps)

        return self.apply_stdp(weights, pre_spike_times, post_spike_times)

    def get_learning_window(
        self, time_range: tuple[float, float] = (-50.0, 50.0), num_points: int = 1000
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generate the STDP learning window curve for visualization.

        Args:
            time_range: (min_time, max_time) in ms, default (-50, 50)
            num_points: Number of points to sample, default 1000

        Returns:
            delta_t: Evenly spaced time differences, shape (num_points,)
            dw: Corresponding weight changes, shape (num_points,)
        """
        # Sample on the configured device, falling back to CPU when CUDA was
        # requested but is not available on this host.
        device = self.device
        if device.type == "cuda" and not torch.cuda.is_available():
            device = torch.device("cpu")

        delta_t = torch.linspace(
            time_range[0], time_range[1], num_points, device=device
        )
        dw = self.compute_weight_change(delta_t)
        return delta_t, dw
|