File size: 2,720 Bytes
440e322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
PLIFNode: D 维固定参数 PLIF 神经元(设计文档 5.5 "普通 SNN 神经元")

与 SelectivePLIFNode 的区别:
  SelectivePLIF: β(t), α(t), V_th(t) 由输入每步动态计算(选择性记忆)
  PLIFNode:      β, V_th 为 D 维可学习参数,训练后固定(信号转换)

每个维度有独立的可学习参数:
  β_d = sigmoid(w_d): 时间常数(衰减率)
  V_th_d: 发放阈值

动力学(与 ParametricLIF 一致):
  V[t] = β · V[t-1] + (1-β) · x[t]
  s[t] = Θ(V[t] - V_th)            (surrogate gradient)
  V[t] -= V_th · s[t]              (soft reset)
"""

import math

import torch
import torch.nn as nn
from spikingjelly.activation_based import base, surrogate


class PLIFNode(base.MemoryModule):
    """
    D 维固定参数 PLIF 神经元。

    Args:
        dim: 神经元数量(每个维度独立参数)
        init_tau: 初始时间常数 τ(β = 1 - 1/τ)
        v_threshold: 初始发放阈值
        surrogate_function: surrogate gradient 函数
    """

    def __init__(
        self,
        dim: int,
        init_tau: float = 2.0,
        v_threshold: float = 0.5,
        surrogate_function=surrogate.Sigmoid(alpha=4.0),
    ):
        super().__init__()
        # D 维可学习参数(随机初始化,每个维度独立)
        # w: 控制 β=sigmoid(w),随机产生不同时间常数
        #    init_w ± 0.5 → β ∈ ~[sigmoid(w-0.5), sigmoid(w+0.5)]
        #    tau=2.0 时 w=0, β ∈ ~[0.38, 0.62]
        init_w = -math.log(init_tau - 1.0)
        self.w = nn.Parameter(torch.empty(dim).normal_(init_w, 0.5))
        # v_th: 发放阈值,U[0.5x, 1.5x] 均匀分布产生维度间多样性
        self.v_th = nn.Parameter(torch.empty(dim).uniform_(
            v_threshold * 0.5, v_threshold * 1.5,
        ))
        self.surrogate_function = surrogate_function
        # 膜电位状态(functional.reset_net 时重置为 0.)
        self.register_memory('v', 0.)

    @property
    def beta(self):
        """D 维衰减率 β = sigmoid(w),值域 (0, 1)。"""
        return torch.sigmoid(self.w)

    def forward(self, x):
        """
        单步前向传播。

        V[t] = β · V[t-1] + (1-β) · x[t], spike = Θ(V-V_th), soft reset。

        Args:
            x: 输入电流, shape (batch, dim)

        Returns:
            spike: 二值脉冲, shape (batch, dim), 值域 {0, 1}
        """
        if isinstance(self.v, float):
            self.v = torch.zeros_like(x)
        beta = self.beta
        self.v = beta * self.v + (1.0 - beta) * x
        spike = self.surrogate_function(self.v - self.v_th)
        self.v = self.v - spike * self.v_th  # soft reset
        return spike