| """ |
| SNNFFN: SNN 等价的 Feed-Forward Network |
| |
| 对标 Qwen3MLP 的 SwiGLU 结构: |
| Qwen3 MLP: down_proj( SiLU(gate_proj(x)) * up_proj(x) ) |
| SNN FFN: down_proj( gate_V_post * up_V_post ) + skip |
| |
| 膜电位门控(对标 SiLU gating): |
| gate/up 神经元完整 PLIF 动力学(积分+阈值+重置), |
| 输出膜电位 V_post 做连续乘法门控,替代 binary AND 门。 |
| |
| 信号流: |
| x → gate_proj → gate_neuron → V_post_gate |
| x → up_proj → up_neuron → V_post_up |
| V_post_gate × V_post_up → gated |
| down_proj(gated) + skip_proj(x) → 连续输出 |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from spikingjelly.activation_based import base, layer, surrogate |
|
|
| from .plif_node import PLIFNode |
| from .parallel_scan import plif_rowparam_forward |
|
|
|
|
| class SNNFFN(base.MemoryModule): |
| """ |
| SNN 等价的 Feed-Forward Network。 |
| |
| Args: |
| D: 可见维度(输入/输出 spike 维度) |
| D_ff: 中间层维度(对标 Qwen3 intermediate_size) |
| output_v_threshold: 输出神经元阈值 |
| num_layers: 总层数,用于 down_proj 缩放 |
| layer_idx: 当前层索引 |
| surrogate_function: surrogate gradient 函数 |
| """ |
|
|
| def __init__( |
| self, |
| D: int, |
| D_ff: int, |
| output_v_threshold: float = 0.3, |
| num_layers: int = 1, |
| layer_idx: int = 0, |
| surrogate_function=surrogate.Sigmoid(alpha=4.0), |
| ): |
| super().__init__() |
        self.D = D
        self.D_ff = D_ff

        # SwiGLU-style projections (no bias; single-step mode)
        self.gate_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
        self.up_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
        self.down_proj = layer.Linear(D_ff, D, bias=False, step_mode='s')

        # Continuous skip path around the spiking branch
        self.skip_proj = layer.Linear(D, D, bias=False, step_mode='s')

        # Gate/up PLIF neurons: full integrate-threshold-reset dynamics.
        # Their post-reset membrane potentials form the continuous gate.
        self.gate_neuron = PLIFNode(
            dim=D_ff,
            init_tau=2.0,
            v_threshold=output_v_threshold,
            surrogate_function=surrogate_function,
        )
        self.up_neuron = PLIFNode(
            dim=D_ff,
            init_tau=2.0,
            v_threshold=output_v_threshold,
            surrogate_function=surrogate_function,
        )

        self._initialize_parameters(num_layers)
|
|
    def _initialize_parameters(self, num_layers: int):
        """Initialize the projection weights.

        - gate_proj, up_proj, skip_proj: Kaiming uniform
        - down_proj: Kaiming uniform × 1/√(num_layers), to keep gradients in
          deep stacks from blowing up
        """
        for lin in [self.gate_proj, self.up_proj, self.skip_proj]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        self.down_proj.weight.data.mul_(1.0 / math.sqrt(num_layers))
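
        # Why 1/√(num_layers): a standard residual-scaling heuristic, not
        # specific to this repo. If each of L layers adds a roughly
        # independent branch output of variance σ², the residual stream grows
        # like L·σ²; scaling every branch by 1/√L keeps the sum near σ².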
|
|
    def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
        """
        Parallel forward pass: process the full sequence with a parallel scan.

        Optimizations:
        - gate_proj + up_proj fused into a single matmul (2 launches → 1)
        - gate + up PLIF dynamics merged into one row-param scan over 2*D_ff
          channels (one kernel launch for both populations; beta/v_th are
          expanded once to the (batch, 2*D_ff) row layout)
        - u_merged: one broadcast multiply by the concatenated (1 - beta) row
          replaces two separate scalings plus a cat

        Args:
            spike_in_seq: (TK, batch, D) — input spikes for all T×K frames

        Returns:
            continuous_out: (TK, batch, D) — continuous outputs for all T×K frames
        """
        TK, batch, D = spike_in_seq.shape
        D_ff = self.D_ff
        flat = spike_in_seq.reshape(TK * batch, D)

        # Fused gate/up projection: one matmul over the concatenated weights
        # instead of two separate launches; the skip current runs alongside.
        W_gate_up = torch.cat([self.gate_proj.weight, self.up_proj.weight], dim=0)
        I_gate_up = F.linear(flat, W_gate_up).reshape(TK, batch, 2 * D_ff)
        I_skip = F.linear(flat, self.skip_proj.weight).reshape(TK, batch, D)

        beta_gate = self.gate_neuron.beta
        beta_up = self.up_neuron.beta
        surr = self.gate_neuron.surrogate_function

        # Pre-scale the fused current by (1 - beta): one broadcast multiply
        # replaces two per-branch scalings plus a cat.
        scale_row = torch.cat([1.0 - beta_gate, 1.0 - beta_up])
        u_merged = I_gate_up * scale_row

        # Per-channel beta / v_th rows, expanded to the (batch, 2*D_ff)
        # layout the row-param kernel expects.
        beta_row = torch.cat([beta_gate, beta_up])
        beta_row = beta_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()

        v_th_row = torch.cat([self.gate_neuron.v_th, self.up_neuron.v_th])
        v_th_row = v_th_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()

        # Initial membrane potentials; a float means a freshly reset neuron,
        # so materialize zeros of the right shape.
        v_init_gate = self.gate_neuron.v
        if isinstance(v_init_gate, float):
            v_init_gate = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
        v_init_up = self.up_neuron.v
        if isinstance(v_init_up, float):
            v_init_up = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
        v_init_merged = torch.cat([v_init_gate, v_init_up], dim=-1)

        # One parallel scan over the merged 2*D_ff channels covers both the
        # gate and the up neuron populations.
        spike_merged, V_post_merged = plif_rowparam_forward(
            beta_row, u_merged, v_th_row, v_init_merged,
            surrogate_function=surr,
        )

        # Continuous gate values: leak-scaled post-reset membrane potentials.
        # Persist the final potentials as the neurons' carried state.
        gate_leak = V_post_merged[:, :, :D_ff] * (1.0 - beta_gate)
        up_leak = V_post_merged[:, :, D_ff:] * (1.0 - beta_up)
        self.gate_neuron.v = V_post_merged[-1, :, :D_ff].detach()
        self.up_neuron.v = V_post_merged[-1, :, D_ff:].detach()

        # SwiGLU-style multiplicative gating, then project down and add skip.
        gated = gate_leak * up_leak
        gated_flat = gated.reshape(TK * batch, D_ff)
        I_out = F.linear(gated_flat, self.down_proj.weight).reshape(TK, batch, D) + I_skip

        return I_out
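
    # Note on equivalence: forward_parallel reads the gate as
    # (1 - beta) * V_post[t], which matches single_step_forward's
    # (1 - beta) * neuron.v after each step. Assuming plif_rowparam_forward
    # and PLIFNode implement the same update, the parallel scan is a drop-in
    # replacement for stepping frame by frame (up to kernel numerics).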
|
|
    def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
        """
        Single-step forward pass.

        Args:
            spike_in: binary spike input, shape (batch, D), values in {0, 1}

        Returns:
            continuous_out: continuous output, shape (batch, D)
        """
        # Gate branch: run the PLIF neuron for its state update; the gate
        # value is the leak-scaled post-reset membrane potential, not the spike.
        _ = self.gate_neuron(self.gate_proj(spike_in))
        gate_leak = (1.0 - self.gate_neuron.beta) * self.gate_neuron.v

        # Up branch: same readout.
        _ = self.up_neuron(self.up_proj(spike_in))
        up_leak = (1.0 - self.up_neuron.beta) * self.up_neuron.v

        # Continuous multiplicative gating (the SiLU-gating analogue).
        gated = gate_leak * up_leak

        # Project down and add the continuous skip path.
        I_out = self.down_proj(gated) + self.skip_proj(spike_in)
        return I_out
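

# Minimal smoke-test sketch (illustrative only). Assumes this file lives in
# its package (relative imports above), so run it as a module; also assumes
# PLIFNode behaves like a spikingjelly MemoryModule (so reset_net applies)
# and that plif_rowparam_forward supports the current device.
if __name__ == "__main__":
    from spikingjelly.activation_based import functional

    torch.manual_seed(0)
    TK, batch, D, D_ff = 8, 2, 16, 64
    ffn = SNNFFN(D=D, D_ff=D_ff, num_layers=4)

    # Parallel path over a random binary spike sequence.
    spikes = torch.randint(0, 2, (TK, batch, D)).float()
    out_par = ffn.forward_parallel(spikes)
    assert out_par.shape == (TK, batch, D)

    # Reset the carried membrane state, then feed the same frames step by step.
    functional.reset_net(ffn)
    out_step = torch.stack([ffn.single_step_forward(spikes[t]) for t in range(TK)])
    assert out_step.shape == (TK, batch, D)
    print("parallel:", tuple(out_par.shape), "single-step:", tuple(out_step.shape))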
|
|