"""
SNNFFN: the SNN-equivalent Feed-Forward Network.

Mirrors the SwiGLU structure of Qwen3MLP:
    Qwen3 MLP: down_proj( SiLU(gate_proj(x)) * up_proj(x) )
    SNN FFN:   down_proj( gate_V_post * up_V_post ) + skip

Membrane-potential gating (the analogue of SiLU gating):
    The gate/up neurons run the full PLIF dynamics (integration + threshold + reset);
    their post-update membrane potentials V_post gate each other through a
    continuous multiplication, replacing a binary AND gate.

Signal flow:
    x → gate_proj → gate_neuron → V_post_gate
    x → up_proj   → up_neuron   → V_post_up
    V_post_gate × V_post_up → gated
    down_proj(gated) + skip_proj(x) → continuous output
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import base, layer, surrogate
from .plif_node import PLIFNode
from .parallel_scan import plif_rowparam_forward


class SNNFFN(base.MemoryModule):
"""
SNN 等价的 Feed-Forward Network。
Args:
D: 可见维度(输入/输出 spike 维度)
D_ff: 中间层维度(对标 Qwen3 intermediate_size)
output_v_threshold: 输出神经元阈值
num_layers: 总层数,用于 down_proj 缩放
layer_idx: 当前层索引
surrogate_function: surrogate gradient 函数
"""

    def __init__(
        self,
        D: int,
        D_ff: int,
        output_v_threshold: float = 0.3,
        num_layers: int = 1,
        layer_idx: int = 0,
        surrogate_function=surrogate.Sigmoid(alpha=4.0),
    ):
        super().__init__()
        self.D = D
        self.D_ff = D_ff

        # ====== Three projection paths (matching SwiGLU: gate_proj, up_proj, down_proj) ======
        self.gate_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
        self.up_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
        self.down_proj = layer.Linear(D_ff, D, bias=False, step_mode='s')

        # ====== Residual path ======
        self.skip_proj = layer.Linear(D, D, bias=False, step_mode='s')

        # ====== Neurons (learnable β and V_th, per channel over D_ff) ======
        # gate_neuron: gating spikes
        self.gate_neuron = PLIFNode(
            dim=D_ff,
            init_tau=2.0,
            v_threshold=output_v_threshold,
            surrogate_function=surrogate_function,
        )
        # up_neuron: value spikes
        self.up_neuron = PLIFNode(
            dim=D_ff,
            init_tau=2.0,
            v_threshold=output_v_threshold,
            surrogate_function=surrogate_function,
        )

        # ====== Parameter initialization ======
        self._initialize_parameters(num_layers)

    def _initialize_parameters(self, num_layers: int):
        """Initialize the projection weights.

        - gate_proj, up_proj, skip_proj: Kaiming uniform
        - down_proj: Kaiming uniform × 1/√(num_layers), to prevent gradient
          explosion in deep stacks
        """
        for lin in [self.gate_proj, self.up_proj, self.skip_proj]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        self.down_proj.weight.data.mul_(1.0 / math.sqrt(num_layers))
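
    # Worked example of the depth scaling above (hypothetical depth): with
    # num_layers=24, down_proj starts at Kaiming scale × 1/sqrt(24) ≈ 0.204.
    # Shrinking only the branch that writes back into the residual stream is a
    # common way to keep activations from growing with depth at initialization.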

    def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
        """
        Parallel forward pass: process the full sequence with a parallel scan.

        Optimizations:
        - gate_proj + up_proj fused into a single matmul (2 kernel launches → 1)
        - gate + up PLIF scan: row-param kernel (no per-timestep expand+contiguous
          of beta/v_th)
        - u_merged: vector scaling instead of cat (one broadcast multiply replaces
          two scales + one cat)

        Args:
            spike_in_seq: (TK, batch, D), input spikes for all T×K frames

        Returns:
            continuous_out: (TK, batch, D), continuous outputs for all T×K frames
        """
        TK, batch, D = spike_in_seq.shape
        D_ff = self.D_ff
        flat = spike_in_seq.reshape(TK * batch, D)

        # ====== Phase 1: batched projections (gate + up fused into one matmul) ======
        W_gate_up = torch.cat([self.gate_proj.weight, self.up_proj.weight], dim=0)
        I_gate_up = F.linear(flat, W_gate_up).reshape(TK, batch, 2 * D_ff)
        I_skip = F.linear(flat, self.skip_proj.weight).reshape(TK, batch, D)
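
        # The fused projection above is equivalent to running the two paths
        # separately, because row-concatenating the weight matrices concatenates
        # the output features:
        #   F.linear(flat, torch.cat([Wg, Wu], dim=0))
        #     == torch.cat([F.linear(flat, Wg), F.linear(flat, Wu)], dim=-1)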

        # ====== Phase 2: merged gate+up PLIF scan (row-param kernel) ======
        beta_gate = self.gate_neuron.beta  # (D_ff,)
        beta_up = self.up_neuron.beta      # (D_ff,)
        surr = self.gate_neuron.surrogate_function
        # u_merged: vector scaling (the D_ff-dim betas are concatenated directly, no expand)
        scale_row = torch.cat([1.0 - beta_gate, 1.0 - beta_up])  # (2*D_ff,)
        u_merged = I_gate_up * scale_row  # (TK, batch, 2*D_ff), broadcast
        # beta_row / v_th_row: (batch, 2*D_ff); learnable D_ff-dim parameters
        beta_row = torch.cat([beta_gate, beta_up])  # (2*D_ff,)
        beta_row = beta_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()
        v_th_row = torch.cat([self.gate_neuron.v_th, self.up_neuron.v_th])  # (2*D_ff,)
        v_th_row = v_th_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()
        # v_init_merged: (batch, 2*D_ff)
        v_init_gate = self.gate_neuron.v
        if isinstance(v_init_gate, float):
            v_init_gate = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
        v_init_up = self.up_neuron.v
        if isinstance(v_init_up, float):
            v_init_up = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
        v_init_merged = torch.cat([v_init_gate, v_init_up], dim=-1)
        # Row-param PLIF scan: beta/v_th are read from registers, costing no HBM bandwidth
        spike_merged, V_post_merged = plif_rowparam_forward(
            beta_row, u_merged, v_th_row, v_init_merged,
            surrogate_function=surr,
        )
        # Membrane-potential leak as the activation: leak = (1-β) · V_post
        gate_leak = V_post_merged[:, :, :D_ff] * (1.0 - beta_gate)  # (TK, batch, D_ff)
        up_leak = V_post_merged[:, :, D_ff:] * (1.0 - beta_up)      # (TK, batch, D_ff)
        self.gate_neuron.v = V_post_merged[-1, :, :D_ff].detach()
        self.up_neuron.v = V_post_merged[-1, :, D_ff:].detach()

        # ====== Phase 3: continuous gating (leak × leak, matching SwiGLU) + down-projection ======
        gated = gate_leak * up_leak  # (TK, batch, D_ff)
        gated_flat = gated.reshape(TK * batch, D_ff)
        I_out = F.linear(gated_flat, self.down_proj.weight).reshape(TK, batch, D) + I_skip
        # output_neuron has been removed: continuous values are aggregated over the K frames at the layer level
        return I_out  # (TK, batch, D), continuous values
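
    # Note: forward_parallel is intended to match TK successive calls to
    # single_step_forward (assuming plif_rowparam_forward implements the same PLIF
    # update as PLIFNode); the .detach() on the final membrane state above means
    # gradients do not flow across chunk boundaries when the sequence is processed
    # in segments.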

    def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
        """
        Single-step forward pass.

        Args:
            spike_in: binary spike input, shape (batch, D), values in {0, 1}

        Returns:
            continuous_out: continuous output, shape (batch, D)
        """
        # Gate path: membrane-leak activation
        _ = self.gate_neuron(self.gate_proj(spike_in))
        gate_leak = (1.0 - self.gate_neuron.beta) * self.gate_neuron.v  # leak
        # Value path: membrane-leak activation
        _ = self.up_neuron(self.up_proj(spike_in))
        up_leak = (1.0 - self.up_neuron.beta) * self.up_neuron.v  # leak
        # Continuous gating (matching SwiGLU)
        gated = gate_leak * up_leak
        # Down-projection + residual
        I_out = self.down_proj(gated) + self.skip_proj(spike_in)  # R^D
        return I_out  # continuous output
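

# Minimal usage sketch (assumptions: this file runs as part of the package so the
# relative imports above resolve, and the plif_rowparam_forward kernel is available
# on the current device; D=8, D_ff=32, TK=4, batch=2 are arbitrary example sizes).
if __name__ == "__main__":
    from spikingjelly.activation_based import functional

    torch.manual_seed(0)
    ffn = SNNFFN(D=8, D_ff=32, output_v_threshold=0.3, num_layers=2, layer_idx=0)

    # Parallel path: (TK, batch, D) binary spikes in, continuous (TK, batch, D) out.
    spikes = (torch.rand(4, 2, 8) > 0.5).float()
    out_par = ffn.forward_parallel(spikes)
    print(out_par.shape)  # expected: torch.Size([4, 2, 8])

    # Single-step path: clear the neuron state, then feed one frame at a time.
    functional.reset_net(ffn)
    outs = [ffn.single_step_forward(spikes[t]) for t in range(spikes.shape[0])]
    print(outs[0].shape)  # expected: torch.Size([2, 8])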