File size: 9,489 Bytes
518db7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""
Spike-Timing Dependent Plasticity (STDP) learning rule implementation.

This module implements biologically plausible STDP learning in plain PyTorch, for
use alongside spiking neural network components (e.g. those of the Norse library).
STDP adjusts synaptic weights based on the relative timing of pre- and
post-synaptic spikes, with Δt = t_post - t_pre:
- Pre-before-post (Δt > 0): Strengthens connection (LTP - Long-Term Potentiation)
- Post-before-pre (Δt < 0): Weakens connection (LTD - Long-Term Depression)

Reference: Bi & Poo (1998) - Synaptic Modifications in Cultured Hippocampal Neurons
"""

import torch
import torch.nn as nn
from typing import Optional
import math

class STDPLearner(nn.Module):
    """
    Spike-Timing Dependent Plasticity learning mechanism.

    Implements the exponential STDP learning window from Bi & Poo (1998):

    ΔW(Δt) =  A+ * exp(-Δt/τ+)  if Δt > 0  (pre before post - LTP)
    ΔW(Δt) = -A- * exp(Δt/τ-)   if Δt < 0  (post before pre - LTD)
    ΔW(Δt) =  0                 if Δt == 0

    Where:
    - Δt = t_post - t_pre: Time difference between post and pre spikes (ms)
    - A+, A-: Maximum weight change amplitudes (A- is applied with a minus sign)
    - τ+, τ-: Time constants for LTP and LTD windows

    Args:
        tau_plus: Time constant for LTP window (ms), default 20.0. Must be > 0.
        tau_minus: Time constant for LTD window (ms), default 20.0. Must be > 0.
        a_plus: LTP amplitude, default 0.005
        a_minus: LTD amplitude, default 0.00525 (slightly asymmetric per Bi & Poo)
        w_min: Minimum synaptic weight, default 0.0
        w_max: Maximum synaptic weight, default 1.0. Must be >= w_min.
        device: Device on which get_learning_window() builds its sample grid
            ('cuda' or 'cpu'), default 'cuda'. The STDP parameters are
            registered as buffers (created on CPU, moved by ``.to()``), and
            compute_weight_change / apply_stdp / forward allocate on the device
            of their input tensors, so CPU inputs work regardless of this value.

    Raises:
        ValueError: If a time constant is not positive or w_min > w_max.

    Reference: Bi & Poo (1998) Figure 1 - Exponential STDP learning window
    """

    def __init__(
        self,
        tau_plus: float = 20.0,
        tau_minus: float = 20.0,
        a_plus: float = 0.005,
        a_minus: float = 0.00525,
        w_min: float = 0.0,
        w_max: float = 1.0,
        device: str = "cuda",
    ):
        super().__init__()
        # `not (x > 0)` also rejects NaN; a zero time constant would divide by zero.
        if not (tau_plus > 0 and tau_minus > 0):
            raise ValueError(
                "tau_plus and tau_minus must be positive, "
                f"got {tau_plus} and {tau_minus}"
            )
        if w_min > w_max:
            raise ValueError(f"w_min ({w_min}) must not exceed w_max ({w_max})")

        self.device = torch.device(device)

        # STDP time constants (register as buffers so they move with model)
        self.register_buffer("tau_plus", torch.tensor(tau_plus, dtype=torch.float32))
        self.register_buffer("tau_minus", torch.tensor(tau_minus, dtype=torch.float32))

        # STDP amplitudes
        self.register_buffer("a_plus", torch.tensor(a_plus, dtype=torch.float32))
        self.register_buffer("a_minus", torch.tensor(a_minus, dtype=torch.float32))

        # Weight bounds
        self.register_buffer("w_min", torch.tensor(w_min, dtype=torch.float32))
        self.register_buffer("w_max", torch.tensor(w_max, dtype=torch.float32))

    def compute_weight_change(
        self, delta_t: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute weight change based on spike timing difference.

        Uses the Bi & Poo (1998) exponential STDP window:
        - Positive Δt (pre before post): LTP (strengthening)
        - Negative Δt (post before pre): LTD (weakening)
        - Δt == 0 or NaN: no change

        Args:
            delta_t: Time difference t_post - t_pre (ms), any shape
                    (e.g. (batch, num_synapses) or (num_synapses,))
                    Positive values = pre before post
                    Negative values = post before pre

        Returns:
            dw: Weight change for each synapse, on delta_t's device
                Shape: same as delta_t
                Positive values = strengthen, negative = weaken
        """
        # Cast buffers to tensors for type safety
        a_plus: torch.Tensor = self.a_plus  # type: ignore[assignment]
        a_minus: torch.Tensor = self.a_minus  # type: ignore[assignment]
        tau_plus: torch.Tensor = self.tau_plus  # type: ignore[assignment]
        tau_minus: torch.Tensor = self.tau_minus  # type: ignore[assignment]

        # -|Δt| equals -Δt on the LTP side and Δt on the LTD side, so both
        # exponents stay <= 0 and the branch torch.where discards never
        # overflows to inf for large |Δt|.
        magnitude = delta_t.abs()
        ltp_change = a_plus * torch.exp(-magnitude / tau_plus)
        ltd_change = -a_minus * torch.exp(-magnitude / tau_minus)

        # zeros_like follows delta_t's device (not self.device), so inputs on
        # any device work regardless of how the learner was configured.
        no_change = torch.zeros_like(ltp_change)
        return torch.where(
            delta_t > 0,
            ltp_change,
            torch.where(delta_t < 0, ltd_change, no_change),
        )

    def apply_stdp(
        self,
        weights: torch.Tensor,
        pre_spike_times: torch.Tensor,
        post_spike_times: torch.Tensor,
        dt: float = 1.0,
    ) -> torch.Tensor:
        """
        Apply STDP weight updates based on spike timing.

        Args:
            weights: Current synaptic weights
                    Shape: (num_pre, num_post) or (batch, num_pre, num_post)
            pre_spike_times: Times of pre-synaptic spikes (ms)
                           Shape: (num_pre,) or (batch, num_pre)
                           Use -inf for neurons that didn't spike
            post_spike_times: Times of post-synaptic spikes (ms)
                            Shape: (num_post,) or (batch, num_post)
                            Use -inf for neurons that didn't spike
            dt: Unused. Spike times are already expected in ms; the parameter
                is kept only for backward compatibility with existing callers.

        Returns:
            updated_weights: Weights after STDP update
                           Shape: same as input weights
                           Clamped to [w_min, w_max]

        Raises:
            ValueError: If pre_spike_times and post_spike_times differ in rank,
                which would otherwise broadcast into a meaningless delta_t.
        """
        if pre_spike_times.dim() != post_spike_times.dim():
            raise ValueError(
                "pre_spike_times and post_spike_times must have the same number "
                f"of dimensions, got {pre_spike_times.dim()} and "
                f"{post_spike_times.dim()}"
            )

        # Compute all pairwise spike time differences
        # Δt = t_post - t_pre (positive if pre before post)
        if pre_spike_times.dim() == 1:
            # Single batch: (num_pre,) x (num_post,) -> (num_pre, num_post)
            delta_t = post_spike_times.unsqueeze(0) - pre_spike_times.unsqueeze(1)
        else:
            # Batched: (batch, num_pre) x (batch, num_post) -> (batch, num_pre, num_post)
            delta_t = post_spike_times.unsqueeze(1) - pre_spike_times.unsqueeze(2)

        # Pairs where either neuron didn't spike give ±inf (one side -inf) or
        # NaN (both -inf); those must not change the weight.
        valid_pairs = torch.isfinite(delta_t)

        dw = self.compute_weight_change(delta_t)
        dw = torch.where(valid_pairs, dw, torch.zeros_like(dw))

        # Clamp to bounds
        w_min: torch.Tensor = self.w_min  # type: ignore[assignment]
        w_max: torch.Tensor = self.w_max  # type: ignore[assignment]
        return torch.clamp(weights + dw, w_min, w_max)

    @staticmethod
    def _last_spike_times(
        spikes: torch.Tensor, start: int, stop: int, dt: float
    ) -> torch.Tensor:
        """Return each neuron's latest spike time (step * dt) in [start, stop), or -inf.

        Args:
            spikes: Spike train of shape (time_steps, num_neurons); > 0 = spike.
            start: First timestep (inclusive) of the window.
            stop: Last timestep (exclusive) of the window.
            dt: Timestep duration (ms) used to convert step indices to times.

        Returns:
            Float32 tensor of shape (num_neurons,) on spikes' device.
        """
        window = spikes[start:stop] > 0  # (steps_in_window, num_neurons)
        no_spike = torch.full(
            (spikes.shape[1],), float("-inf"), device=spikes.device, dtype=torch.float32
        )
        if window.shape[0] == 0:
            # Empty window (time_window <= 0 or no timesteps): nobody spiked.
            return no_spike

        step_times = torch.arange(
            start, start + window.shape[0], device=spikes.device, dtype=torch.float32
        ) * dt
        # Spiking steps carry their time, others -inf; the max over time is the
        # most recent spike (dt > 0 keeps times increasing with the step index).
        candidates = torch.where(window, step_times.unsqueeze(1), no_spike.unsqueeze(0))
        return candidates.amax(dim=0)

    def forward(
        self,
        weights: torch.Tensor,
        pre_spikes: torch.Tensor,
        post_spikes: torch.Tensor,
        time_window: int = 50,
        dt: float = 1.0,
    ) -> torch.Tensor:
        """
        Forward pass: Apply STDP learning to weights based on spike trains.

        Extracts the most recent spike of each neuron within the last
        ``time_window`` timesteps (nearest-spike approximation) and applies
        the STDP rule to those timings.

        Args:
            weights: Current synaptic weights (num_pre, num_post)
            pre_spikes: Binary spike train for pre-synaptic neurons
                       Shape: (time_steps, num_pre)
                       1 = spike, 0 = no spike
            post_spikes: Binary spike train for post-synaptic neurons
                        Shape: (time_steps, num_post)
                        1 = spike, 0 = no spike
            time_window: Number of most recent timesteps considered, default 50
                        (equal to ms when dt == 1.0)
            dt: Duration of one timestep (ms), default 1.0. Step indices are
                multiplied by dt to obtain spike times in ms.

        Returns:
            updated_weights: Weights after STDP update (num_pre, num_post),
                on the device of the inputs

        Raises:
            ValueError: If dt is not positive.
        """
        if not dt > 0:
            raise ValueError(f"dt must be positive, got {dt}")

        time_steps = pre_spikes.shape[0]
        # The window is measured on the pre-synaptic train and applied to both.
        start = max(0, time_steps - time_window)

        # Spike times are built on each spike train's own device rather than
        # self.device, so CPU inputs work even for a CUDA-configured learner.
        pre_spike_times = self._last_spike_times(pre_spikes, start, time_steps, dt)
        post_spike_times = self._last_spike_times(post_spikes, start, time_steps, dt)

        return self.apply_stdp(weights, pre_spike_times, post_spike_times)

    def get_learning_window(
        self, time_range: tuple[float, float] = (-50.0, 50.0), num_points: int = 1000
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generate the STDP learning window curve for visualization.

        Useful for validating that the implementation matches Bi & Poo (1998) Figure 1.

        Args:
            time_range: (min_time, max_time) in ms, default (-50, 50)
            num_points: Number of points to sample, default 1000

        Returns:
            delta_t: Time differences (ms), shape (num_points,), on self.device
            dw: Corresponding weight changes, shape (num_points,)
        """
        delta_t = torch.linspace(
            time_range[0], time_range[1], num_points, device=self.device
        )
        dw = self.compute_weight_change(delta_t)

        return delta_t, dw