File size: 14,323 Bytes
440e322 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 | """
SNNDecoderLayer: 单个 SNN 解码层(Pre-LN 连续残差流 + 动态 K 帧聚合)
RMSNorm → PLIF → SNNBlock → 动态K聚合 → out_proj → 残差
RMSNorm → PLIF → SNNFFN → 动态K聚合 → out_proj → 残差
动态 K:
- K 是最大步数(K_max),不是固定步数。不同 token 有效步数 ∈ [1, K_max]。
- 每个 token 的 K 帧 SNN 输出,学习自适应停止概率 p_halt
- PonderNet 几何分布加权:λ_k = p_k · ∏_{j<k}(1-p_j),归一化后加权聚合
- 不同 token 有效步数不同:简单 token 早停(E[K]小),复杂 token 用满步数
- ponder_cost 正则化:鼓励用更少步数完成简单 token 的处理
数学推导:
停止概率: p_k = σ(halt_proj(frame_k)) ∈ (0,1)
生存概率: S_k = ∏_{j=1}^{k-1} (1 - p_j) — 到第 k 步还没停
权重: λ_k = p_k · S_k — 恰好在第 k 步停止的概率
归一化: λ̂_k = λ_k / Σ_k λ_k — 确保权重和为 1
聚合: output = Σ_k λ̂_k · frame_k
代价: E[K] = Σ_k k · λ̂_k — 期望步数
K_max 设计原则:
K_max 越大,模型对复杂 token 的处理能力越强(更多步数可用),
但计算量和显存线性增长。K_max=32 允许 token 使用 1~32 步。
PonderNet 的 ponder_cost 正则化确保简单 token 不浪费步数。
K 帧层间聚合:
- SNN 子层输出 K 帧连续值(V_post 经投影),PonderNet 加权聚合为 1 per token
- 聚合后经 out_proj 投影,广播回 K 帧做残差
- 使 β 的时间动力学通过 K 帧聚合梯度有效传播
对标 Qwen3DecoderLayer(Pre-LN 模式完全等价):
Qwen3: RMSNorm → Attention → residual → RMSNorm → MLP → residual
SNN: RMSNorm → PLIF → SNNBlock → 动态K聚合 → out_proj → residual
→ RMSNorm → PLIF → SNNFFN → 动态K聚合 → out_proj → residual
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import base, surrogate
from .plif_node import PLIFNode
from .rms_norm import RMSNorm
from .snn_block import SNNBlock
from .snn_ffn import SNNFFN
from .parallel_scan import plif_rowparam_forward
# ====== Fused halt weight computation (torch.compile) ======
# 7-8 个独立 element-wise kernel → 单 fused kernel
# sigmoid + clamp + log1p + cumsum + exp + normalize
# 首次调用触发 JIT 编译(~秒级),后续调用走缓存
@torch.compile(backend='inductor', fullgraph=True)
def _fused_geometric_halt(halt_logits):
"""融合计算 PonderNet 几何分布停止权重。
输入: halt_logits (seq_len, K, batch) — halt_proj 的原始输出
输出: halt_weights (seq_len, K, batch) — 归一化几何分布权重,sum=1
数学: p_k = σ(logit_k), S_k = ∏_{j<k}(1-p_j), λ_k = p_k·S_k, λ̂_k = λ_k/Σλ
"""
p_halt = torch.sigmoid(halt_logits).clamp(min=1e-7, max=1.0 - 1e-7)
log_1_minus_p = torch.log1p(-p_halt) # (seq_len, K, batch)
# Exclusive cumsum: log_survive[:, k, :] = Σ_{j<k} log(1-p_j)
# 避免 torch.cat: 用 cumsum([:, :-1]) 填充 [:, 1:]
log_survive = torch.zeros_like(log_1_minus_p)
log_survive[:, 1:, :] = torch.cumsum(log_1_minus_p[:, :-1, :], dim=1)
survive = torch.exp(log_survive) # (seq_len, K, batch)
halt_weights = p_halt * survive # λ_k = p_k · S_k
halt_weights = halt_weights / (halt_weights.sum(dim=1, keepdim=True) + 1e-8)
return halt_weights
class SNNDecoderLayer(base.MemoryModule):
"""
单个 SNN 解码层(连续残差流 + K 帧聚合版本)。
层间传递连续值 h (TK, batch, D),通过 PLIF 神经元转换为 spike,
输入 SNN 子层处理后,K 帧聚合为 1 per token,经 out_proj 投影,
广播回 K 帧做残差连接。
K 帧聚合使 β 的时间动力学(控制 K 步内的膜电位演化)产生可微分的
token 级效应,解决 β 梯度为纯噪声的问题。
Args:
D: 可见维度
N: 状态扩展因子
D_ff: FFN 中间层维度
v_th_min: SNNBlock 动态阈值下限
ffn_v_threshold: SNNFFN gate/up 神经元阈值
K: 每 token 的 SNN 时间步数
num_layers: 总层数(用于残差输出缩放 + SNNFFN down_proj 缩放)
layer_idx: 当前层索引
"""
def __init__(
self,
D: int,
N: int,
D_ff: int,
v_th_min: float,
ffn_v_threshold: float,
K: int = 16,
num_layers: int = 1,
layer_idx: int = 0,
):
super().__init__()
self.D = D
self.K = K
self.snn_block = SNNBlock(
D=D, N=N, v_th_min=v_th_min,
)
self.snn_ffn = SNNFFN(
D=D, D_ff=D_ff,
output_v_threshold=ffn_v_threshold,
num_layers=num_layers,
layer_idx=layer_idx,
)
# Pre-LN 分支归一化: h → RMSNorm → PLIFNode
self.block_norm = RMSNorm(D)
self.ffn_norm = RMSNorm(D)
# 输入神经元: RMSNorm(h) → V_post 膜电位激活(D 维可学习 β 和 V_th)
self.input_neuron1 = PLIFNode(
dim=D,
init_tau=2.0,
v_threshold=0.5,
surrogate_function=surrogate.Sigmoid(alpha=4.0),
)
self.input_neuron2 = PLIFNode(
dim=D,
init_tau=2.0,
v_threshold=0.5,
surrogate_function=surrogate.Sigmoid(alpha=4.0),
)
# 输出投影(突触): spike (D) → 连续空间 (D)
self.block_out_proj = nn.Linear(D, D, bias=False)
self.ffn_out_proj = nn.Linear(D, D, bias=False)
# ====== 动态 K: 停止投影(突触: SNN 输出 → 停止概率) ======
# halt_proj: D → 1,每步每 token 产生一个停止 logit
# PonderNet 几何分布加权,替代 uniform mean 聚合
self.block_halt = nn.Linear(D, 1, bias=True)
self.ffn_halt = nn.Linear(D, 1, bias=True)
# 残差输出缩放初始化(GPT-2 style: σ = 0.02 / √(2·num_layers))
std = 0.02 / math.sqrt(2 * num_layers)
nn.init.normal_(self.block_out_proj.weight, std=std)
nn.init.normal_(self.ffn_out_proj.weight, std=std)
# halt 初始化: 小权重 + 负偏置 → p_halt ≈ 0.03 → 接近 uniform 聚合
# σ(-3.5) ≈ 0.029, 几何分布归一化后 λ_1/λ_K ≈ 1.5, 接近均匀
for halt in [self.block_halt, self.ffn_halt]:
nn.init.xavier_uniform_(halt.weight)
halt.weight.data.mul_(0.01)
nn.init.constant_(halt.bias, -3.5)
def _input_neuron_parallel(self, input_neuron, x):
"""
输入 PLIF 神经元的 parallel scan 前向传播。
完整 PLIF 动力学: V[t] = β·V[t-1] + (1-β)·x[t], spike = Θ(V-V_th), 软重置。
输出膜电位泄漏量 (1-β)·V_post 作为激活值——即每步因指数衰减将泄漏的量。
相比直接传递 V_post,泄漏量自然强调快响应神经元(大 1-β),
抑制慢记忆神经元(小 1-β),实现隐式的时间尺度加权。
Args:
input_neuron: PLIFNode 实例(D 维可学习 β 和 V_th)
x: (TK, batch, D) — 连续值输入
Returns:
leak: (TK, batch, D) — 膜电位泄漏量 (1-β)·V_post
"""
TK, batch, D = x.shape
beta = input_neuron.beta # (D,)
u = (1.0 - beta) * x # (D,) broadcast → (TK, batch, D)
v_init = input_neuron.v
if isinstance(v_init, float):
v_init = torch.zeros(batch, D, device=x.device, dtype=x.dtype)
beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
v_th_row = input_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()
spike, V_post = plif_rowparam_forward(
beta_row, u, v_th_row, v_init,
surrogate_function=input_neuron.surrogate_function,
)
input_neuron.v = V_post[-1].detach()
return (1.0 - beta) * V_post # 膜电位泄漏量
def _adaptive_aggregate(self, frames, halt_proj):
"""
PonderNet 式自适应 K 帧聚合(动态 K 核心,torch.compile 融合优化)。
每步计算停止概率 p_k,用几何分布权重加权聚合,
使不同 token 有不同的有效步数。
优化: _fused_geometric_halt 将 sigmoid+log1p+cumsum+exp+normalize
融合为单 inductor kernel(参见 snn_block._fused_modulation 同一模式)。
数学:
p_k = σ(halt_proj(frame_k)) — 停止概率
S_k = ∏_{j<k} (1-p_j) — 生存概率
λ_k = p_k · S_k — 几何分布权重
λ̂_k = λ_k / Σ λ_k — 归一化
output = Σ λ̂_k · frame_k — 加权聚合
E[K] = Σ k · λ̂_k — 期望步数(ponder cost)
Args:
frames: (seq_len, K, batch, D) — SNN 子层 K 帧输出
halt_proj: nn.Linear(D, 1) — 停止投影(突触)
Returns:
aggregated: (seq_len, batch, D) — 加权聚合结果
ponder_cost: scalar — 期望步数均值(正则化用)
"""
seq_len, K, batch, D = frames.shape
# ====== 1. halt_proj matmul(cuBLAS)+ 融合几何权重(inductor) ======
halt_logits = halt_proj(frames).squeeze(-1) # (seq_len, K, batch)
halt_weights = _fused_geometric_halt(halt_logits) # (seq_len, K, batch), 归一化
# ====== 2. 加权聚合 ======
# (seq_len, K, batch, 1) × (seq_len, K, batch, D) → sum → (seq_len, batch, D)
aggregated = (frames * halt_weights.unsqueeze(-1)).sum(dim=1)
# ====== 3. Ponder cost: E[K] per token ======
steps = torch.arange(1, K + 1, device=frames.device, dtype=frames.dtype)
expected_k = (halt_weights * steps[None, :, None]).sum(dim=1) # (seq_len, batch)
ponder_cost = expected_k.mean() # scalar
return aggregated, ponder_cost, expected_k.detach()
def forward_parallel(self, h):
"""
并行前向传播:连续残差流 + 动态 K 帧聚合。
SNN 子层在 TK 维度处理(K 步时间动力学),输出后用 PonderNet
自适应聚合 K 帧(不同 token 有效步数不同),经 out_proj 投影后
广播回 TK 做残差。
Args:
h: (TK, batch, D) — 连续值输入
Returns:
h: (TK, batch, D) — 连续值输出
ponder_cost: scalar — 两个子层的平均期望步数(正则化用)
"""
TK, batch, D = h.shape
K = self.K
seq_len = TK // K
# 子层 1: SNNBlock — RMSNorm → PLIFNode(V_post) → SNNBlock → 动态K聚合 → out_proj → 残差
v_in = self._input_neuron_parallel(self.input_neuron1, self.block_norm(h))
cont_block = self.snn_block.forward_parallel(v_in) # (TK, batch, D), 连续值
# 动态 K 帧聚合(PonderNet): (TK, batch, D) → (seq_len, K, batch, D) → 加权 → (seq_len, batch, D)
frames_block = cont_block.view(seq_len, K, batch, D)
combined_block, pc_block, ek_block = self._adaptive_aggregate(frames_block, self.block_halt)
res_block = self.block_out_proj(combined_block) # (seq_len, batch, D)
res_block = res_block - res_block.mean(dim=-1, keepdim=True) # 残差中心化
# 广播回 TK:每 token 的残差复制 K 份
h = h + res_block.repeat_interleave(K, dim=0)
# 子层 2: SNNFFN — RMSNorm → PLIFNode(V_post) → SNNFFN → 动态K聚合 → out_proj → 残差
v_in2 = self._input_neuron_parallel(self.input_neuron2, self.ffn_norm(h))
cont_ffn = self.snn_ffn.forward_parallel(v_in2) # (TK, batch, D), 连续值
frames_ffn = cont_ffn.view(seq_len, K, batch, D)
combined_ffn, pc_ffn, ek_ffn = self._adaptive_aggregate(frames_ffn, self.ffn_halt)
res_ffn = self.ffn_out_proj(combined_ffn)
res_ffn = res_ffn - res_ffn.mean(dim=-1, keepdim=True)
h = h + res_ffn.repeat_interleave(K, dim=0)
ponder_cost = (pc_block + pc_ffn) / 2.0 # 两个子层平均
# 存储 per-token E[K] 范围(诊断用,不影响计算图)
# ek_block/ek_ffn: (seq_len, batch), detached
with torch.no_grad():
all_ek = torch.cat([ek_block.flatten(), ek_ffn.flatten()])
self._ek_min = all_ek.min().item()
self._ek_max = all_ek.max().item()
return h, ponder_cost
def single_step_forward(self, h):
"""
单步前向传播:连续残差流。
注意:单步模式无法做动态 K 聚合(每步独立处理)。
训练和推理均使用 forward_parallel(含动态 K 聚合)。
此方法仅用于调试。
Args:
h: (batch, D) — 连续值输入
Returns:
h: (batch, D) — 连续值输出
ponder_cost: scalar — 0.0(单步无 ponder cost)
"""
# 子层 1: SNNBlock — RMSNorm → PLIFNode(leak) → SNNBlock → out_proj → 残差
_ = self.input_neuron1(self.block_norm(h)) # 触发 PLIF 动力学,更新 .v
v_in = (1.0 - self.input_neuron1.beta) * self.input_neuron1.v # 膜电位泄漏量
cont_block = self.snn_block.single_step_forward(v_in)
res_block = self.block_out_proj(cont_block)
h = h + res_block - res_block.mean(dim=-1, keepdim=True)
# 子层 2: SNNFFN — RMSNorm → PLIFNode(leak) → SNNFFN → out_proj → 残差
_ = self.input_neuron2(self.ffn_norm(h))
v_in2 = (1.0 - self.input_neuron2.beta) * self.input_neuron2.v # 膜电位泄漏量
cont_ffn = self.snn_ffn.single_step_forward(v_in2)
res_ffn = self.ffn_out_proj(cont_ffn)
h = h + res_ffn - res_ffn.mean(dim=-1, keepdim=True)
return h, torch.tensor(0.0, device=h.device)
|