""" Spike-Timing Dependent Plasticity (STDP) learning rule implementation. This module implements biologically-plausible STDP learning using Norse library's spiking neural network components. STDP adjusts synaptic weights based on the relative timing of pre- and post-synaptic spikes: - Pre-before-post (Δt > 0): Strengthens connection (LTP - Long-Term Potentiation) - Post-before-pre (Δt < 0): Weakens connection (LTD - Long-Term Depression) Reference: Bi & Poo (1998) - Synaptic Modifications in Cultured Hippocampal Neurons """ import torch import torch.nn as nn from typing import Optional import math class STDPLearner(nn.Module): """ Spike-Timing Dependent Plasticity learning mechanism. Implements the exponential STDP learning window from Bi & Poo (1998): ΔW(Δt) = A+ * exp(-Δt/τ+) if Δt > 0 (pre before post - LTP) ΔW(Δt) = A- * exp(Δt/τ-) if Δt < 0 (post before pre - LTD) Where: - Δt: Time difference between pre and post spikes (ms) - A+, A-: Maximum weight change amplitudes - τ+, τ-: Time constants for LTP and LTD windows Args: tau_plus: Time constant for LTP window (ms), default 20.0 tau_minus: Time constant for LTD window (ms), default 20.0 a_plus: LTP amplitude, default 0.005 a_minus: LTD amplitude, default 0.00525 (slightly asymmetric per Bi & Poo) w_min: Minimum synaptic weight, default 0.0 w_max: Maximum synaptic weight, default 1.0 device: Device to use ('cuda' or 'cpu'), default 'cuda' Reference: Bi & Poo (1998) Figure 1 - Exponential STDP learning window """ def __init__( self, tau_plus: float = 20.0, tau_minus: float = 20.0, a_plus: float = 0.005, a_minus: float = 0.00525, w_min: float = 0.0, w_max: float = 1.0, device: str = "cuda", ): super().__init__() self.device = torch.device(device) # STDP time constants (register as buffers so they move with model) self.register_buffer("tau_plus", torch.tensor(tau_plus, dtype=torch.float32)) self.register_buffer("tau_minus", torch.tensor(tau_minus, dtype=torch.float32)) # STDP amplitudes self.register_buffer("a_plus", torch.tensor(a_plus, dtype=torch.float32)) self.register_buffer("a_minus", torch.tensor(a_minus, dtype=torch.float32)) # Weight bounds self.register_buffer("w_min", torch.tensor(w_min, dtype=torch.float32)) self.register_buffer("w_max", torch.tensor(w_max, dtype=torch.float32)) def compute_weight_change( self, delta_t: torch.Tensor ) -> torch.Tensor: """ Compute weight change based on spike timing difference. 
    def compute_weight_change(self, delta_t: torch.Tensor) -> torch.Tensor:
        """
        Compute weight change based on spike timing difference.

        Uses the Bi & Poo (1998) exponential STDP window:
        - Positive Δt (pre before post): LTP (strengthening)
        - Negative Δt (post before pre): LTD (weakening)

        Args:
            delta_t: Time difference between pre and post spikes (ms)
                Shape: (batch, num_synapses) or (num_synapses,)
                Positive values = pre before post
                Negative values = post before pre

        Returns:
            dw: Weight change for each synapse
                Shape: same as delta_t
                Positive values = strengthen, negative = weaken
        """
        # Cast buffers to tensors for type safety
        a_plus: torch.Tensor = self.a_plus  # type: ignore[assignment]
        a_minus: torch.Tensor = self.a_minus  # type: ignore[assignment]
        tau_plus: torch.Tensor = self.tau_plus  # type: ignore[assignment]
        tau_minus: torch.Tensor = self.tau_minus  # type: ignore[assignment]

        # LTP: pre before post (Δt > 0)
        ltp_mask = delta_t > 0
        ltp_change = a_plus * torch.exp(-delta_t / tau_plus)

        # LTD: post before pre (Δt < 0)
        ltd_mask = delta_t < 0
        ltd_change = -a_minus * torch.exp(delta_t / tau_minus)

        # Combine LTP and LTD; coincident spikes (Δt == 0) leave weights unchanged
        dw = torch.zeros_like(delta_t)
        dw = torch.where(ltp_mask, ltp_change, dw)
        dw = torch.where(ltd_mask, ltd_change, dw)

        return dw

    def apply_stdp(
        self,
        weights: torch.Tensor,
        pre_spike_times: torch.Tensor,
        post_spike_times: torch.Tensor,
        dt: float = 1.0,
    ) -> torch.Tensor:
        """
        Apply STDP weight updates based on spike timing.

        Args:
            weights: Current synaptic weights
                Shape: (num_pre, num_post) or (batch, num_pre, num_post)
            pre_spike_times: Times of pre-synaptic spikes, in units of dt
                Shape: (num_pre,) or (batch, num_pre)
                Use -inf for neurons that didn't spike
            post_spike_times: Times of post-synaptic spikes, in units of dt
                Shape: (num_post,) or (batch, num_post)
                Use -inf for neurons that didn't spike
            dt: Timestep resolution (ms), default 1.0; converts spike-time
                differences to ms before evaluating the learning window

        Returns:
            updated_weights: Weights after STDP update
                Shape: same as input weights
                Clamped to [w_min, w_max]
        """
        # Compute all pairwise spike time differences
        # Δt = t_post - t_pre (positive if pre before post)
        if pre_spike_times.dim() == 1:
            # Single batch: (num_pre,) x (num_post,) -> (num_pre, num_post)
            delta_t = post_spike_times.unsqueeze(0) - pre_spike_times.unsqueeze(1)
        else:
            # Batched: (batch, num_pre) x (batch, num_post) -> (batch, num_pre, num_post)
            delta_t = post_spike_times.unsqueeze(1) - pre_spike_times.unsqueeze(2)

        # Convert from timesteps to milliseconds
        delta_t = delta_t * dt

        # Mask out pairs where either neuron didn't spike (spike_time = -inf)
        valid_pairs = torch.isfinite(delta_t)

        # Compute weight changes
        dw = self.compute_weight_change(delta_t)

        # Zero out invalid pairs
        dw = torch.where(valid_pairs, dw, torch.zeros_like(dw))

        # Update weights
        updated_weights = weights + dw

        # Clamp to bounds
        w_min: torch.Tensor = self.w_min  # type: ignore[assignment]
        w_max: torch.Tensor = self.w_max  # type: ignore[assignment]
        updated_weights = torch.clamp(updated_weights, w_min, w_max)

        return updated_weights
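    # A hedged worked example of apply_stdp with dt = 1.0 (illustrative numbers):
    #
    #   pre_spike_times  = [ 5., -inf]    # pre neuron 1 never spiked
    #   post_spike_times = [10.]
    #   delta_t = t_post - t_pre = [[ 5.],     # pre 5 ms before post -> LTP
    #                               [ inf]]    # non-finite -> masked, dw = 0
    #   dw[0, 0] = a_plus * exp(-5/20) ≈ 0.005 * 0.7788 ≈ +0.00389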
    def forward(
        self,
        weights: torch.Tensor,
        pre_spikes: torch.Tensor,
        post_spikes: torch.Tensor,
        time_window: int = 50,
    ) -> torch.Tensor:
        """
        Forward pass: apply STDP learning to weights based on spike trains.

        This method processes binary spike trains, extracts spike timings, and
        applies the STDP learning rule. Only the most recent spike of each
        neuron within the time window is used (a nearest-spike approximation).

        Args:
            weights: Current synaptic weights (num_pre, num_post)
            pre_spikes: Binary spike train for pre-synaptic neurons
                Shape: (time_steps, num_pre)
                1 = spike, 0 = no spike
            post_spikes: Binary spike train for post-synaptic neurons
                Shape: (time_steps, num_post)
                1 = spike, 0 = no spike
            time_window: Maximum time window for STDP, in timesteps
                (equal to ms at the default 1 ms resolution), default 50

        Returns:
            updated_weights: Weights after STDP update (num_pre, num_post)
        """
        time_steps, num_pre = pre_spikes.shape
        num_post = post_spikes.shape[1]

        # Extract spike times (find when each neuron spiked)
        # Use -inf for neurons that didn't spike
        pre_spike_times = torch.full(
            (num_pre,), float("-inf"), device=self.device, dtype=torch.float32
        )
        post_spike_times = torch.full(
            (num_post,), float("-inf"), device=self.device, dtype=torch.float32
        )

        # Find the most recent spike of each neuron within the time window
        for t in range(max(0, time_steps - time_window), time_steps):
            # Update pre-synaptic spike times
            pre_spiked = pre_spikes[t] > 0
            pre_spike_times = torch.where(
                pre_spiked,
                torch.tensor(float(t), device=self.device),
                pre_spike_times,
            )

            # Update post-synaptic spike times
            post_spiked = post_spikes[t] > 0
            post_spike_times = torch.where(
                post_spiked,
                torch.tensor(float(t), device=self.device),
                post_spike_times,
            )

        # Apply STDP
        updated_weights = self.apply_stdp(
            weights, pre_spike_times, post_spike_times
        )

        return updated_weights

    def get_learning_window(
        self,
        time_range: tuple[float, float] = (-50.0, 50.0),
        num_points: int = 1000,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generate the STDP learning window curve for visualization.

        Useful for validating that the implementation matches
        Bi & Poo (1998) Figure 1.

        Args:
            time_range: (min_time, max_time) in ms, default (-50, 50)
            num_points: Number of points to sample, default 1000

        Returns:
            delta_t: Time differences (ms), shape (num_points,)
            dw: Corresponding weight changes, shape (num_points,)
        """
        delta_t = torch.linspace(
            time_range[0], time_range[1], num_points, device=self.device
        )
        dw = self.compute_weight_change(delta_t)
        return delta_t, dw
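
if __name__ == "__main__":
    # A minimal usage sketch (hedged: the spike statistics and shapes below are
    # illustrative assumptions, not part of the module). Drives the learner with
    # random Bernoulli spike trains on CPU and samples the learning window.
    torch.manual_seed(0)

    learner = STDPLearner(device="cpu")

    time_steps, num_pre, num_post = 100, 8, 4
    weights = torch.full((num_pre, num_post), 0.5)
    pre_spikes = (torch.rand(time_steps, num_pre) < 0.05).float()  # ~5% spike prob/step
    post_spikes = (torch.rand(time_steps, num_post) < 0.05).float()

    updated = learner(weights, pre_spikes, post_spikes)
    print(f"mean weight change: {(updated - weights).mean().item():+.6f}")

    # The sampled window should reproduce the Bi & Poo (1998) Fig. 1 shape:
    # a positive decaying branch for Δt > 0 and a negative one for Δt < 0.
    delta_t, dw = learner.get_learning_window()
    idx = torch.argmin((delta_t - 10.0).abs())
    print(f"dw at Δt ≈ +10 ms: {dw[idx].item():+.6f}")  # expected ≈ +0.00303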