| """ |
| Spike-Timing Dependent Plasticity (STDP) learning rule implementation. |
| |
| This module implements biologically-plausible STDP learning using Norse library's |
| spiking neural network components. STDP adjusts synaptic weights based on the |
| relative timing of pre- and post-synaptic spikes: |
| - Pre-before-post (Δt > 0): Strengthens connection (LTP - Long-Term Potentiation) |
| - Post-before-pre (Δt < 0): Weakens connection (LTD - Long-Term Depression) |
| |
| Reference: Bi & Poo (1998) - Synaptic Modifications in Cultured Hippocampal Neurons |
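
Worked example (illustrative numbers, using the default parameters defined below):
with A+ = 0.005 and τ+ = 20 ms, a pre-synaptic spike arriving 10 ms before a
post-synaptic spike gives ΔW = 0.005 * exp(-10 / 20) ≈ 0.003, a small
potentiation of that synapse.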
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Optional |
| import math |
|
|
class STDPLearner(nn.Module):
    """
    Spike-Timing Dependent Plasticity learning mechanism.

    Implements the exponential STDP learning window from Bi & Poo (1998):

        ΔW(Δt) =  A+ * exp(-Δt/τ+)    if Δt > 0  (pre before post - LTP)
        ΔW(Δt) = -A- * exp(Δt/τ-)     if Δt < 0  (post before pre - LTD)

    Where:
        - Δt: Time difference between pre and post spikes (ms); Δt = t_post - t_pre
        - A+, A-: Maximum weight change amplitudes
        - τ+, τ-: Time constants for the LTP and LTD windows

    Args:
        tau_plus: Time constant for LTP window (ms), default 20.0
        tau_minus: Time constant for LTD window (ms), default 20.0
        a_plus: LTP amplitude, default 0.005
        a_minus: LTD amplitude, default 0.00525 (slightly asymmetric per Bi & Poo)
        w_min: Minimum synaptic weight, default 0.0
        w_max: Maximum synaptic weight, default 1.0
        device: Device to use ('cuda' or 'cpu'), default 'cuda'

    Reference: Bi & Poo (1998) Figure 1 - Exponential STDP learning window
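
    Example (illustrative usage sketch; shapes, spike probabilities, and the CPU
    device are chosen for demonstration only):
        >>> stdp = STDPLearner(device="cpu")
        >>> weights = torch.rand(4, 3)                       # (num_pre, num_post)
        >>> pre_spikes = (torch.rand(50, 4) < 0.1).float()   # (time_steps, num_pre)
        >>> post_spikes = (torch.rand(50, 3) < 0.1).float()  # (time_steps, num_post)
        >>> new_weights = stdp(weights, pre_spikes, post_spikes)
        >>> new_weights.shape
        torch.Size([4, 3])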
| """ |
|
|
| def __init__( |
| self, |
| tau_plus: float = 20.0, |
| tau_minus: float = 20.0, |
| a_plus: float = 0.005, |
| a_minus: float = 0.00525, |
| w_min: float = 0.0, |
| w_max: float = 1.0, |
| device: str = "cuda", |
| ): |
| super().__init__() |
| self.device = torch.device(device) |
|
|
| |
| self.register_buffer("tau_plus", torch.tensor(tau_plus, dtype=torch.float32)) |
| self.register_buffer("tau_minus", torch.tensor(tau_minus, dtype=torch.float32)) |
|
|
| |
| self.register_buffer("a_plus", torch.tensor(a_plus, dtype=torch.float32)) |
| self.register_buffer("a_minus", torch.tensor(a_minus, dtype=torch.float32)) |
|
|
| |
| self.register_buffer("w_min", torch.tensor(w_min, dtype=torch.float32)) |
| self.register_buffer("w_max", torch.tensor(w_max, dtype=torch.float32)) |
|
|
    def compute_weight_change(self, delta_t: torch.Tensor) -> torch.Tensor:
        """
        Compute weight change based on spike timing difference.

        Uses the Bi & Poo (1998) exponential STDP window:
        - Positive Δt (pre before post): LTP (strengthening)
        - Negative Δt (post before pre): LTD (weakening)
        - Δt == 0 (simultaneous spikes): no change

        Args:
            delta_t: Time difference between pre and post spikes (ms)
                Shape: (batch, num_synapses) or (num_synapses,)
                Positive values = pre before post
                Negative values = post before pre

        Returns:
            dw: Weight change for each synapse
                Shape: same as delta_t
                Positive values = strengthen, negative = weaken
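
        Example (illustrative values, using the default parameters):
            >>> stdp = STDPLearner(device="cpu")
            >>> dw = stdp.compute_weight_change(torch.tensor([10.0, -10.0, 0.0]))
            >>> # dw ≈ [+0.0030 (LTP), -0.0032 (LTD), 0.0 (no change)]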
| """ |
| |
| a_plus: torch.Tensor = self.a_plus |
| a_minus: torch.Tensor = self.a_minus |
| tau_plus: torch.Tensor = self.tau_plus |
| tau_minus: torch.Tensor = self.tau_minus |
|
|
| |
| ltp_mask = delta_t > 0 |
| ltp_change = a_plus * torch.exp(-delta_t / tau_plus) |
|
|
| |
| ltd_mask = delta_t < 0 |
| ltd_change = -a_minus * torch.exp(delta_t / tau_minus) |
|
|
| |
| dw = torch.zeros_like(delta_t, device=self.device) |
| dw = torch.where(ltp_mask, ltp_change, dw) |
| dw = torch.where(ltd_mask, ltd_change, dw) |
|
|
| return dw |
|
|
    def apply_stdp(
        self,
        weights: torch.Tensor,
        pre_spike_times: torch.Tensor,
        post_spike_times: torch.Tensor,
        dt: float = 1.0,
    ) -> torch.Tensor:
        """
        Apply STDP weight updates based on spike timing.

        Args:
            weights: Current synaptic weights
                Shape: (num_pre, num_post) or (batch, num_pre, num_post)
            pre_spike_times: Times of pre-synaptic spikes (ms)
                Shape: (num_pre,) or (batch, num_pre)
                Use -inf for neurons that didn't spike
            post_spike_times: Times of post-synaptic spikes (ms)
                Shape: (num_post,) or (batch, num_post)
                Use -inf for neurons that didn't spike
            dt: Timestep resolution (ms), default 1.0
                (currently unused; spike times are assumed to already be in ms)

        Returns:
            updated_weights: Weights after STDP update
                Shape: same as input weights
                Clamped to [w_min, w_max]
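
        Example (illustrative sketch with hand-picked spike times):
            >>> stdp = STDPLearner(device="cpu")
            >>> weights = torch.full((2, 1), 0.5)
            >>> pre_t = torch.tensor([0.0, 20.0])  # pre neuron 0 fires early, neuron 1 fires late
            >>> post_t = torch.tensor([10.0])      # the post neuron fires at t = 10 ms
            >>> updated = stdp.apply_stdp(weights, pre_t, post_t)
            >>> # synapse 0 (pre before post) is potentiated, synapse 1 (post before pre) is depressed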
| """ |
| |
| |
| if pre_spike_times.dim() == 1: |
| |
| delta_t = post_spike_times.unsqueeze(0) - pre_spike_times.unsqueeze(1) |
| else: |
| |
| delta_t = post_spike_times.unsqueeze(1) - pre_spike_times.unsqueeze(2) |
|
|
| |
| valid_pairs = torch.isfinite(delta_t) |
|
|
| |
| dw = self.compute_weight_change(delta_t) |
|
|
| |
| dw = torch.where(valid_pairs, dw, torch.zeros_like(dw)) |
|
|
| |
| updated_weights = weights + dw |
|
|
| |
| w_min: torch.Tensor = self.w_min |
| w_max: torch.Tensor = self.w_max |
| updated_weights = torch.clamp(updated_weights, w_min, w_max) |
|
|
| return updated_weights |
|
|
    def forward(
        self,
        weights: torch.Tensor,
        pre_spikes: torch.Tensor,
        post_spikes: torch.Tensor,
        time_window: int = 50,
    ) -> torch.Tensor:
        """
        Forward pass: apply STDP learning to weights based on spike trains.

        This method processes binary spike trains, extracts the most recent spike
        time of each neuron within the last `time_window` steps, and applies the
        STDP learning rule to those timings.

        Args:
            weights: Current synaptic weights (num_pre, num_post)
            pre_spikes: Binary spike train for pre-synaptic neurons
                Shape: (time_steps, num_pre)
                1 = spike, 0 = no spike
            post_spikes: Binary spike train for post-synaptic neurons
                Shape: (time_steps, num_post)
                1 = spike, 0 = no spike
            time_window: Maximum time window for STDP in timesteps
                (equal to ms at the default 1 ms resolution), default 50

        Returns:
            updated_weights: Weights after STDP update (num_pre, num_post)
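
        Example (illustrative sketch; a single pre and a single post neuron):
            >>> stdp = STDPLearner(device="cpu")
            >>> pre = torch.zeros(50, 1); pre[5, 0] = 1.0     # pre neuron spikes at step 5
            >>> post = torch.zeros(50, 1); post[15, 0] = 1.0  # post neuron spikes at step 15
            >>> w = stdp(torch.tensor([[0.5]]), pre, post)    # pre-before-post, so the weight increases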
| """ |
| time_steps, num_pre = pre_spikes.shape |
| num_post = post_spikes.shape[1] |
|
|
| |
| |
| pre_spike_times = torch.full( |
| (num_pre,), float('-inf'), device=self.device, dtype=torch.float32 |
| ) |
| post_spike_times = torch.full( |
| (num_post,), float('-inf'), device=self.device, dtype=torch.float32 |
| ) |
|
|
| |
| for t in range(max(0, time_steps - time_window), time_steps): |
| |
| pre_spiked = pre_spikes[t] > 0 |
| pre_spike_times = torch.where( |
| pre_spiked, |
| torch.tensor(float(t), device=self.device), |
| pre_spike_times |
| ) |
|
|
| |
| post_spiked = post_spikes[t] > 0 |
| post_spike_times = torch.where( |
| post_spiked, |
| torch.tensor(float(t), device=self.device), |
| post_spike_times |
| ) |
|
|
| |
| updated_weights = self.apply_stdp( |
| weights, pre_spike_times, post_spike_times |
| ) |
|
|
| return updated_weights |
|
|
    def get_learning_window(
        self, time_range: tuple[float, float] = (-50.0, 50.0), num_points: int = 1000
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generate the STDP learning window curve for visualization.

        Useful for validating that the implementation matches Bi & Poo (1998) Figure 1.

        Args:
            time_range: (min_time, max_time) in ms, default (-50, 50)
            num_points: Number of points to sample, default 1000

        Returns:
            delta_t: Time differences (ms), shape (num_points,)
            dw: Corresponding weight changes, shape (num_points,)
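
        Example (visualization sketch; matplotlib is assumed to be available but is
        not a dependency of this module):
            >>> stdp = STDPLearner(device="cpu")
            >>> delta_t, dw = stdp.get_learning_window()
            >>> # import matplotlib.pyplot as plt
            >>> # plt.plot(delta_t.cpu(), dw.cpu()); plt.xlabel("Δt (ms)"); plt.ylabel("ΔW")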
| """ |
| delta_t = torch.linspace( |
| time_range[0], time_range[1], num_points, device=self.device |
| ) |
| dw = self.compute_weight_change(delta_t) |
|
|
| return delta_t, dw |
|
|