File size: 2,720 Bytes
440e322 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | """
PLIFNode: D 维固定参数 PLIF 神经元(设计文档 5.5 "普通 SNN 神经元")
与 SelectivePLIFNode 的区别:
SelectivePLIF: β(t), α(t), V_th(t) 由输入每步动态计算(选择性记忆)
PLIFNode: β, V_th 为 D 维可学习参数,训练后固定(信号转换)
每个维度有独立的可学习参数:
β_d = sigmoid(w_d): 时间常数(衰减率)
V_th_d: 发放阈值
动力学(与 ParametricLIF 一致):
V[t] = β · V[t-1] + (1-β) · x[t]
s[t] = Θ(V[t] - V_th) (surrogate gradient)
V[t] -= V_th · s[t] (soft reset)
"""
import math
import torch
import torch.nn as nn
from spikingjelly.activation_based import base, surrogate
class PLIFNode(base.MemoryModule):
"""
D 维固定参数 PLIF 神经元。
Args:
dim: 神经元数量(每个维度独立参数)
init_tau: 初始时间常数 τ(β = 1 - 1/τ)
v_threshold: 初始发放阈值
surrogate_function: surrogate gradient 函数
"""
def __init__(
self,
dim: int,
init_tau: float = 2.0,
v_threshold: float = 0.5,
surrogate_function=surrogate.Sigmoid(alpha=4.0),
):
super().__init__()
# D 维可学习参数(随机初始化,每个维度独立)
# w: 控制 β=sigmoid(w),随机产生不同时间常数
# init_w ± 0.5 → β ∈ ~[sigmoid(w-0.5), sigmoid(w+0.5)]
# tau=2.0 时 w=0, β ∈ ~[0.38, 0.62]
init_w = -math.log(init_tau - 1.0)
self.w = nn.Parameter(torch.empty(dim).normal_(init_w, 0.5))
# v_th: 发放阈值,U[0.5x, 1.5x] 均匀分布产生维度间多样性
self.v_th = nn.Parameter(torch.empty(dim).uniform_(
v_threshold * 0.5, v_threshold * 1.5,
))
self.surrogate_function = surrogate_function
# 膜电位状态(functional.reset_net 时重置为 0.)
self.register_memory('v', 0.)
@property
def beta(self):
"""D 维衰减率 β = sigmoid(w),值域 (0, 1)。"""
return torch.sigmoid(self.w)
def forward(self, x):
"""
单步前向传播。
V[t] = β · V[t-1] + (1-β) · x[t], spike = Θ(V-V_th), soft reset。
Args:
x: 输入电流, shape (batch, dim)
Returns:
spike: 二值脉冲, shape (batch, dim), 值域 {0, 1}
"""
if isinstance(self.v, float):
self.v = torch.zeros_like(x)
beta = self.beta
self.v = beta * self.v + (1.0 - beta) * x
spike = self.surrogate_function(self.v - self.v_th)
self.v = self.v - spike * self.v_th # soft reset
return spike
|