# Brain2nd — Initial release: NeuronSpark-0.9B-Chat instruction-tuned SNN
# language model (commit 440e322, verified).
"""
SNNBlock: 完整的 SNN 隐状态空间 Block(并行化版本)
结构(每个 SNN 时间步):
spike_in {0,1}^D
├─→ W_in → I[t] ∈ R^{D*N}
├─→ W_β^(x) + b_β → σ → β(t)
├─→ W_α^(x) + b_α → softplus → α(t)
├─→ W_th^(x) + b_th → |·|+V_min → V_th(t)
├─→ W_gate → sigmoid → gate ∈ (0,1)^D
└─→ W_skip → I_skip ∈ R^D
SelectivePLIF(I, β, α, V_th) → s[t] ∈ {0,1}^{D*N}
W_out · V_post[t] ⊙ gate + I_skip → 连续输出 ∈ R^D
数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import base, layer, surrogate
from .selective_plif import SelectivePLIFNode
from .parallel_scan import plif_parallel_forward
# ====== Fused modulation activations (torch.compile) ======
# Fuse sigmoid + softplus + abs + alpha*I into single kernel.
# 7-8 separate element-wise kernels → 1 fused kernel, ~4x speedup on DN-sized tensors.
# First call triggers JIT compilation (~seconds); cached for subsequent calls.
@torch.compile(backend='inductor', fullgraph=True)
def _fused_modulation(raw_beta, b_beta, raw_alpha, b_alpha, raw_th, b_th, v_th_min, I_all):
    """Fused element-wise modulation activations (compiled into one kernel).

    Given the raw linear projections and their biases, computes:
        beta = sigmoid(raw_beta + b_beta)              # decay factor in (0, 1)
        u    = softplus(raw_alpha + b_alpha) * I_all   # alpha-scaled input current
        v_th = v_th_min + |raw_th + b_th|              # dynamic threshold >= v_th_min

    Args:
        raw_beta, raw_alpha, raw_th: pre-activation projections, shape (TK, batch, DN).
        b_beta, b_alpha, b_th: per-unit biases, shape (DN,) — broadcast over leading dims.
        v_th_min: scalar lower bound on the dynamic threshold.
        I_all: input current tensor, same shape as raw_alpha.

    Returns:
        Tuple ``(beta, u, v_th)``; each tensor has the broadcast shape (TK, batch, DN).
    """
    beta = torch.sigmoid(raw_beta + b_beta)
    alpha = F.softplus(raw_alpha + b_alpha)
    v_th = v_th_min + torch.abs(raw_th + b_th)
    # alpha is folded into the current here so callers receive a single tensor.
    u = alpha * I_all
    return beta, u, v_th
class SNNBlock(base.MemoryModule):
    """
    A single parallelized SNN selective state-space block.

    Per SNN time step:
        spike_in {0,1}^D
          ├─→ W_in           → I[t] ∈ R^{D*N}
          ├─→ W_β^(x) + b_β  → sigmoid  → β(t)
          ├─→ W_α^(x) + b_α  → softplus → α(t)
          ├─→ W_th^(x) + b_th → |·| + V_min → V_th(t)
          ├─→ W_gate → sigmoid → gate ∈ (0,1)^D
          └─→ W_skip → I_skip ∈ R^D
        SelectivePLIF(I, β, α, V_th) → s[t] ∈ {0,1}^{D*N}
        W_out · V_post[t] ⊙ gate + I_skip → continuous output ∈ R^D

    Args:
        D: visible dimension (width of inter-block communication).
        N: state expansion factor (hidden neurons per channel).
        v_th_min: lower bound of the dynamic firing threshold.
        surrogate_function: surrogate gradient function; defaults to
            ``surrogate.Sigmoid(alpha=4.0)``. A fresh instance is created per
            block — the previous signature-level default was a single module
            instance shared by every ``SNNBlock``.
    """

    def __init__(
        self,
        D: int,
        N: int = 8,
        v_th_min: float = 0.1,
        surrogate_function=None,
    ):
        super().__init__()
        self.D = D
        self.N = N
        self.v_th_min = v_th_min
        DN = D * N
        # FIX: avoid a mutable default argument — a surrogate module built in
        # the signature would be evaluated once and shared across instances.
        if surrogate_function is None:
            surrogate_function = surrogate.Sigmoid(alpha=4.0)
        # ====== Six parallel input projections (SNN synapses, spike input) ======
        self.W_in = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_beta_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_alpha_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_th_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_gate = layer.Linear(D, D, bias=False, step_mode='s')
        self.W_skip = layer.Linear(D, D, bias=False, step_mode='s')
        # β/α/V_th depend only on spike_in (no W^(V)·V term).
        # ====== Modulation biases (structured initialization below) ======
        self.b_beta = nn.Parameter(torch.empty(DN))
        self.b_alpha = nn.Parameter(torch.empty(DN))
        self.b_th = nn.Parameter(torch.empty(DN))
        # ====== Output projection: D*N → D (SNN synapse) ======
        self.W_out = layer.Linear(DN, D, bias=False, step_mode='s')
        # ====== Hidden state-space neurons (D*N units, dynamic parameters) ======
        self.hidden_neuron = SelectivePLIFNode(
            surrogate_function=surrogate_function,
            detach_reset=False,
        )
        # ====== Parameter initialization ======
        self._initialize_parameters()

    def _initialize_parameters(self):
        """Function-guided initialization of weights and modulation biases."""
        D, N = self.D, self.N
        K_ref = 16
        # Target β distribution: multiple time scales in [0.80, 0.99].
        beta_values = torch.linspace(0.80, 0.99, N)
        # ====== 1. β bias: logit-spaced + per-dimension random perturbation ======
        b_beta_per_n = torch.log(beta_values / (1.0 - beta_values))
        # Use the per-n values as means and add N(0, 0.1) noise to break the
        # symmetry across the D channels.
        self.b_beta.data.copy_(b_beta_per_n.repeat(D))
        self.b_beta.data.add_(torch.empty_like(self.b_beta).normal_(0, 0.1))
        # ====== 2. α bias: softplus(0.5413) ≈ 1.0 + per-dimension noise ======
        # Mean 0.5413 with N(0, 0.1) noise → α ∈ ~[0.7, 1.3].
        self.b_alpha.data.normal_(0.5413, 0.1)
        # ====== 3. W^(x) weights ======
        for lin in [self.W_in, self.W_gate, self.W_skip, self.W_out]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
        for lin in [self.W_beta_x, self.W_alpha_x, self.W_th_x]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
            lin.weight.data.mul_(0.1)
        # ====== 4. W_in time-scale scaling ======
        scale_per_n = torch.sqrt(1.0 - beta_values ** 2)  # (N,)
        scale_DN = scale_per_n.repeat(D)  # (D*N,)
        with torch.no_grad():
            self.W_in.weight.mul_(scale_DN.unsqueeze(1))
        # ====== 5. b_th: σ_V calibration ======
        # σ_V = sqrt(p/3) * sqrt(1 - β^{2K}), where p is the input firing rate.
        # The old version assumed p=0.5 (σ_I=0.408), but the actual
        # input_neuron firing rate is ~0.07–0.45 and lower in deeper layers.
        # Use p=0.15 as a conservative estimate to avoid v_th being so high
        # that neurons go silent.
        p_assumed = 0.15
        sigma_I_base = math.sqrt(p_assumed / 3.0)
        sigma_V_per_n = sigma_I_base * torch.sqrt(
            1.0 - beta_values ** (2 * K_ref)
        )
        target_p_fire = torch.linspace(0.25, 0.08, N)
        z_scores = math.sqrt(2.0) * torch.erfinv(
            2.0 * (1.0 - target_p_fire) - 1.0
        )
        target_V_th = sigma_V_per_n * z_scores
        b_th_per_n = torch.clamp(target_V_th - self.v_th_min, min=0.05)
        # Per-n means plus N(0, 0.02) noise to break channel symmetry.
        self.b_th.data.copy_(b_th_per_n.repeat(D))
        self.b_th.data.add_(torch.empty_like(self.b_th).normal_(0, 0.02))
        # ====== 6. W_out firing-rate balancing ======
        out_scale_per_n = 1.0 / torch.sqrt(target_p_fire)
        out_scale_per_n = out_scale_per_n / out_scale_per_n.mean()
        out_scale_DN = out_scale_per_n.repeat(D)
        with torch.no_grad():
            self.W_out.weight.mul_(out_scale_DN.unsqueeze(0))

    def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
        """
        Parallel forward pass: process the whole sequence with a parallel scan.

        Args:
            spike_in_seq: (TK, batch, D) — input spikes for all T×K frames.
        Returns:
            continuous_out: (TK, batch, D) — continuous outputs for all frames
                (V_post projected through W_out, gated, plus skip current).
        """
        TK, batch, D = spike_in_seq.shape
        DN = self.D * self.N
        # ====== Phase 1: batched projections (all TK frames at once) ======
        flat = spike_in_seq.reshape(TK * batch, D)
        I_all = F.linear(flat, self.W_in.weight).reshape(TK, batch, DN)
        raw_beta = F.linear(flat, self.W_beta_x.weight).reshape(TK, batch, DN)
        raw_alpha = F.linear(flat, self.W_alpha_x.weight).reshape(TK, batch, DN)
        raw_th = F.linear(flat, self.W_th_x.weight).reshape(TK, batch, DN)
        gate_all = torch.sigmoid(
            F.linear(flat, self.W_gate.weight).reshape(TK, batch, D)
        )
        I_skip_all = F.linear(flat, self.W_skip.weight).reshape(TK, batch, D)
        # ====== Phase 1b: fused activations (torch.compile → single kernel) ======
        beta_all, u_hidden, v_th_all = _fused_modulation(
            raw_beta, self.b_beta, raw_alpha, self.b_alpha,
            raw_th, self.b_th, self.v_th_min, I_all,
        )
        # Fetch the hidden neurons' initial state (float means "not yet reset
        # to a tensor" — start from zeros).
        v_init_hidden = self.hidden_neuron.v
        if isinstance(v_init_hidden, float):
            v_init_hidden = torch.zeros(batch, DN, device=flat.device, dtype=flat.dtype)
        # Spikes are not needed downstream — the output uses V_post directly.
        _, V_post_hidden, _ = plif_parallel_forward(
            beta_all, u_hidden, v_th_all, v_init_hidden, max_iter=3,
            surrogate_function=self.hidden_neuron.surrogate_function,
        )
        # Persist the last-step membrane state for the next call.
        self.hidden_neuron.v = V_post_hidden[-1].detach()
        # ====== Phase 4: output projection (V_post → W_out, gradient flows to β) ======
        # V_post (membrane voltage) replaces spikes as the W_out input, removing
        # the surrogate-gradient bottleneck:
        #   spike path:  ∂spike/∂β = surrogate'(V - v_th) · V_prev ≈ 0 most steps
        #   V_post path: ∂V_post/∂β = V_prev (no surrogate blocking; gradient
        #   every step)
        v_flat = V_post_hidden.reshape(TK * batch, DN)
        I_out_all = F.linear(v_flat, self.W_out.weight).reshape(TK, batch, D)
        I_total_all = I_out_all * gate_all + I_skip_all  # (TK, batch, D)
        # output_neuron removed: continuous values are aggregated over K frames
        # at the layer level.
        return I_total_all  # (TK, batch, D), continuous values

    def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
        """
        Single-step forward pass (debugging / compatibility path).

        Args:
            spike_in: binary spike input, shape (batch, D), values in {0, 1}.
        Returns:
            continuous_out: continuous output, shape (batch, D).
        """
        # FIX: removed a dead computation of V_prev (the pre-step membrane
        # state was materialized but never used).
        I_t = self.W_in(spike_in)
        # β/α/V_th modulation depends only on spike_in.
        beta = torch.sigmoid(self.W_beta_x(spike_in) + self.b_beta)
        alpha = F.softplus(self.W_alpha_x(spike_in) + self.b_alpha)
        v_th = self.v_th_min + torch.abs(self.W_th_x(spike_in) + self.b_th)
        gate = torch.sigmoid(self.W_gate(spike_in))
        I_skip = self.W_skip(spike_in)
        # The neuron call updates self.hidden_neuron.v in place; the returned
        # spikes are unused because the output projects the membrane voltage.
        self.hidden_neuron(I_t, beta, alpha, v_th)
        # Use V_post (membrane voltage) for the output projection, consistent
        # with forward_parallel.
        V_post = self.hidden_neuron.v  # membrane potential after fire + reset
        I_out = self.W_out(V_post)
        I_total = I_out * gate + I_skip
        return I_total  # continuous values