File size: 14,323 Bytes
440e322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
"""
SNNDecoderLayer: 单个 SNN 解码层(Pre-LN 连续残差流 + 动态 K 帧聚合)

  RMSNorm → PLIF → SNNBlock → 动态K聚合 → out_proj → 残差
  RMSNorm → PLIF → SNNFFN   → 动态K聚合 → out_proj → 残差

动态 K:
  - K 是最大步数(K_max),不是固定步数。不同 token 有效步数 ∈ [1, K_max]。
  - 每个 token 的 K 帧 SNN 输出,学习自适应停止概率 p_halt
  - PonderNet 几何分布加权:λ_k = p_k · ∏_{j<k}(1-p_j),归一化后加权聚合
  - 不同 token 有效步数不同:简单 token 早停(E[K]小),复杂 token 用满步数
  - ponder_cost 正则化:鼓励用更少步数完成简单 token 的处理

  数学推导:
    停止概率: p_k = σ(halt_proj(frame_k))         ∈ (0,1)
    生存概率: S_k = ∏_{j=1}^{k-1} (1 - p_j)       — 到第 k 步还没停
    权重:     λ_k = p_k · S_k                       — 恰好在第 k 步停止的概率
    归一化:   λ̂_k = λ_k / Σ_k λ_k                   — 确保权重和为 1
    聚合:     output = Σ_k λ̂_k · frame_k
    代价:     E[K] = Σ_k k · λ̂_k                    — 期望步数

  K_max 设计原则:
    K_max 越大,模型对复杂 token 的处理能力越强(更多步数可用),
    但计算量和显存线性增长。K_max=32 允许 token 使用 1~32 步。
    PonderNet 的 ponder_cost 正则化确保简单 token 不浪费步数。

K 帧层间聚合:
  - SNN 子层输出 K 帧连续值(V_post 经投影),PonderNet 加权聚合为 1 per token
  - 聚合后经 out_proj 投影,广播回 K 帧做残差
  - 使 β 的时间动力学通过 K 帧聚合梯度有效传播

对标 Qwen3DecoderLayer(Pre-LN 模式完全等价):
  Qwen3:  RMSNorm → Attention → residual → RMSNorm → MLP → residual
  SNN:    RMSNorm → PLIF → SNNBlock → 动态K聚合 → out_proj → residual
        → RMSNorm → PLIF → SNNFFN   → 动态K聚合 → out_proj → residual
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import base, surrogate

from .plif_node import PLIFNode
from .rms_norm import RMSNorm
from .snn_block import SNNBlock
from .snn_ffn import SNNFFN
from .parallel_scan import plif_rowparam_forward


# ====== Fused halt weight computation (torch.compile) ======
# 7-8 个独立 element-wise kernel → 单 fused kernel
# sigmoid + clamp + log1p + cumsum + exp + normalize
# 首次调用触发 JIT 编译(~秒级),后续调用走缓存

@torch.compile(backend='inductor', fullgraph=True)
def _fused_geometric_halt(halt_logits):
    """融合计算 PonderNet 几何分布停止权重。

    输入: halt_logits (seq_len, K, batch) — halt_proj 的原始输出
    输出: halt_weights (seq_len, K, batch) — 归一化几何分布权重,sum=1

    数学: p_k = σ(logit_k), S_k = ∏_{j<k}(1-p_j), λ_k = p_k·S_k, λ̂_k = λ_k/Σλ
    """
    p_halt = torch.sigmoid(halt_logits).clamp(min=1e-7, max=1.0 - 1e-7)
    log_1_minus_p = torch.log1p(-p_halt)               # (seq_len, K, batch)
    # Exclusive cumsum: log_survive[:, k, :] = Σ_{j<k} log(1-p_j)
    # 避免 torch.cat: 用 cumsum([:, :-1]) 填充 [:, 1:]
    log_survive = torch.zeros_like(log_1_minus_p)
    log_survive[:, 1:, :] = torch.cumsum(log_1_minus_p[:, :-1, :], dim=1)
    survive = torch.exp(log_survive)                    # (seq_len, K, batch)
    halt_weights = p_halt * survive                     # λ_k = p_k · S_k
    halt_weights = halt_weights / (halt_weights.sum(dim=1, keepdim=True) + 1e-8)
    return halt_weights


class SNNDecoderLayer(base.MemoryModule):
    """
    单个 SNN 解码层(连续残差流 + K 帧聚合版本)。

    层间传递连续值 h (TK, batch, D),通过 PLIF 神经元转换为 spike,
    输入 SNN 子层处理后,K 帧聚合为 1 per token,经 out_proj 投影,
    广播回 K 帧做残差连接。

    K 帧聚合使 β 的时间动力学(控制 K 步内的膜电位演化)产生可微分的
    token 级效应,解决 β 梯度为纯噪声的问题。

    Args:
        D: 可见维度
        N: 状态扩展因子
        D_ff: FFN 中间层维度
        v_th_min: SNNBlock 动态阈值下限
        ffn_v_threshold: SNNFFN gate/up 神经元阈值
        K: 每 token 的 SNN 时间步数
        num_layers: 总层数(用于残差输出缩放 + SNNFFN down_proj 缩放)
        layer_idx: 当前层索引
    """

    def __init__(
        self,
        D: int,
        N: int,
        D_ff: int,
        v_th_min: float,
        ffn_v_threshold: float,
        K: int = 16,
        num_layers: int = 1,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.D = D
        self.K = K

        self.snn_block = SNNBlock(
            D=D, N=N, v_th_min=v_th_min,
        )
        self.snn_ffn = SNNFFN(
            D=D, D_ff=D_ff,
            output_v_threshold=ffn_v_threshold,
            num_layers=num_layers,
            layer_idx=layer_idx,
        )

        # Pre-LN 分支归一化: h → RMSNorm → PLIFNode
        self.block_norm = RMSNorm(D)
        self.ffn_norm = RMSNorm(D)

        # 输入神经元: RMSNorm(h) → V_post 膜电位激活(D 维可学习 β 和 V_th)
        self.input_neuron1 = PLIFNode(
            dim=D,
            init_tau=2.0,
            v_threshold=0.5,
            surrogate_function=surrogate.Sigmoid(alpha=4.0),
        )
        self.input_neuron2 = PLIFNode(
            dim=D,
            init_tau=2.0,
            v_threshold=0.5,
            surrogate_function=surrogate.Sigmoid(alpha=4.0),
        )

        # 输出投影(突触): spike (D) → 连续空间 (D)
        self.block_out_proj = nn.Linear(D, D, bias=False)
        self.ffn_out_proj = nn.Linear(D, D, bias=False)

        # ====== 动态 K: 停止投影(突触: SNN 输出 → 停止概率) ======
        # halt_proj: D → 1,每步每 token 产生一个停止 logit
        # PonderNet 几何分布加权,替代 uniform mean 聚合
        self.block_halt = nn.Linear(D, 1, bias=True)
        self.ffn_halt = nn.Linear(D, 1, bias=True)

        # 残差输出缩放初始化(GPT-2 style: σ = 0.02 / √(2·num_layers))
        std = 0.02 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.block_out_proj.weight, std=std)
        nn.init.normal_(self.ffn_out_proj.weight, std=std)

        # halt 初始化: 小权重 + 负偏置 → p_halt ≈ 0.03 → 接近 uniform 聚合
        # σ(-3.5) ≈ 0.029, 几何分布归一化后 λ_1/λ_K ≈ 1.5, 接近均匀
        for halt in [self.block_halt, self.ffn_halt]:
            nn.init.xavier_uniform_(halt.weight)
            halt.weight.data.mul_(0.01)
            nn.init.constant_(halt.bias, -3.5)

    def _input_neuron_parallel(self, input_neuron, x):
        """
        输入 PLIF 神经元的 parallel scan 前向传播。

        完整 PLIF 动力学: V[t] = β·V[t-1] + (1-β)·x[t], spike = Θ(V-V_th), 软重置。
        输出膜电位泄漏量 (1-β)·V_post 作为激活值——即每步因指数衰减将泄漏的量。
        相比直接传递 V_post,泄漏量自然强调快响应神经元(大 1-β),
        抑制慢记忆神经元(小 1-β),实现隐式的时间尺度加权。

        Args:
            input_neuron: PLIFNode 实例(D 维可学习 β 和 V_th)
            x: (TK, batch, D) — 连续值输入

        Returns:
            leak: (TK, batch, D) — 膜电位泄漏量 (1-β)·V_post
        """
        TK, batch, D = x.shape

        beta = input_neuron.beta  # (D,)
        u = (1.0 - beta) * x  # (D,) broadcast → (TK, batch, D)

        v_init = input_neuron.v
        if isinstance(v_init, float):
            v_init = torch.zeros(batch, D, device=x.device, dtype=x.dtype)

        beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
        v_th_row = input_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()

        spike, V_post = plif_rowparam_forward(
            beta_row, u, v_th_row, v_init,
            surrogate_function=input_neuron.surrogate_function,
        )

        input_neuron.v = V_post[-1].detach()
        return (1.0 - beta) * V_post  # 膜电位泄漏量

    def _adaptive_aggregate(self, frames, halt_proj):
        """
        PonderNet 式自适应 K 帧聚合(动态 K 核心,torch.compile 融合优化)。

        每步计算停止概率 p_k,用几何分布权重加权聚合,
        使不同 token 有不同的有效步数。

        优化: _fused_geometric_halt 将 sigmoid+log1p+cumsum+exp+normalize
        融合为单 inductor kernel(参见 snn_block._fused_modulation 同一模式)。

        数学:
          p_k = σ(halt_proj(frame_k))                 — 停止概率
          S_k = ∏_{j<k} (1-p_j)                       — 生存概率
          λ_k = p_k · S_k                             — 几何分布权重
          λ̂_k = λ_k / Σ λ_k                           — 归一化
          output = Σ λ̂_k · frame_k                    — 加权聚合
          E[K] = Σ k · λ̂_k                            — 期望步数(ponder cost)

        Args:
            frames: (seq_len, K, batch, D) — SNN 子层 K 帧输出
            halt_proj: nn.Linear(D, 1)    — 停止投影(突触)

        Returns:
            aggregated: (seq_len, batch, D) — 加权聚合结果
            ponder_cost: scalar             — 期望步数均值(正则化用)
        """
        seq_len, K, batch, D = frames.shape

        # ====== 1. halt_proj matmul(cuBLAS)+ 融合几何权重(inductor) ======
        halt_logits = halt_proj(frames).squeeze(-1)    # (seq_len, K, batch)
        halt_weights = _fused_geometric_halt(halt_logits)  # (seq_len, K, batch), 归一化

        # ====== 2. 加权聚合 ======
        # (seq_len, K, batch, 1) × (seq_len, K, batch, D) → sum → (seq_len, batch, D)
        aggregated = (frames * halt_weights.unsqueeze(-1)).sum(dim=1)

        # ====== 3. Ponder cost: E[K] per token ======
        steps = torch.arange(1, K + 1, device=frames.device, dtype=frames.dtype)
        expected_k = (halt_weights * steps[None, :, None]).sum(dim=1)  # (seq_len, batch)
        ponder_cost = expected_k.mean()               # scalar

        return aggregated, ponder_cost, expected_k.detach()

    def forward_parallel(self, h):
        """
        并行前向传播:连续残差流 + 动态 K 帧聚合。

        SNN 子层在 TK 维度处理(K 步时间动力学),输出后用 PonderNet
        自适应聚合 K 帧(不同 token 有效步数不同),经 out_proj 投影后
        广播回 TK 做残差。

        Args:
            h: (TK, batch, D) — 连续值输入

        Returns:
            h: (TK, batch, D) — 连续值输出
            ponder_cost: scalar — 两个子层的平均期望步数(正则化用)
        """
        TK, batch, D = h.shape
        K = self.K
        seq_len = TK // K

        # 子层 1: SNNBlock — RMSNorm → PLIFNode(V_post) → SNNBlock → 动态K聚合 → out_proj → 残差
        v_in = self._input_neuron_parallel(self.input_neuron1, self.block_norm(h))
        cont_block = self.snn_block.forward_parallel(v_in)  # (TK, batch, D), 连续值

        # 动态 K 帧聚合(PonderNet): (TK, batch, D) → (seq_len, K, batch, D) → 加权 → (seq_len, batch, D)
        frames_block = cont_block.view(seq_len, K, batch, D)
        combined_block, pc_block, ek_block = self._adaptive_aggregate(frames_block, self.block_halt)
        res_block = self.block_out_proj(combined_block)  # (seq_len, batch, D)
        res_block = res_block - res_block.mean(dim=-1, keepdim=True)  # 残差中心化

        # 广播回 TK:每 token 的残差复制 K 份
        h = h + res_block.repeat_interleave(K, dim=0)

        # 子层 2: SNNFFN — RMSNorm → PLIFNode(V_post) → SNNFFN → 动态K聚合 → out_proj → 残差
        v_in2 = self._input_neuron_parallel(self.input_neuron2, self.ffn_norm(h))
        cont_ffn = self.snn_ffn.forward_parallel(v_in2)  # (TK, batch, D), 连续值

        frames_ffn = cont_ffn.view(seq_len, K, batch, D)
        combined_ffn, pc_ffn, ek_ffn = self._adaptive_aggregate(frames_ffn, self.ffn_halt)
        res_ffn = self.ffn_out_proj(combined_ffn)
        res_ffn = res_ffn - res_ffn.mean(dim=-1, keepdim=True)

        h = h + res_ffn.repeat_interleave(K, dim=0)

        ponder_cost = (pc_block + pc_ffn) / 2.0  # 两个子层平均

        # 存储 per-token E[K] 范围(诊断用,不影响计算图)
        # ek_block/ek_ffn: (seq_len, batch), detached
        with torch.no_grad():
            all_ek = torch.cat([ek_block.flatten(), ek_ffn.flatten()])
            self._ek_min = all_ek.min().item()
            self._ek_max = all_ek.max().item()

        return h, ponder_cost

    def single_step_forward(self, h):
        """
        单步前向传播:连续残差流。

        注意:单步模式无法做动态 K 聚合(每步独立处理)。
        训练和推理均使用 forward_parallel(含动态 K 聚合)。
        此方法仅用于调试。

        Args:
            h: (batch, D) — 连续值输入

        Returns:
            h: (batch, D) — 连续值输出
            ponder_cost: scalar — 0.0(单步无 ponder cost)
        """
        # 子层 1: SNNBlock — RMSNorm → PLIFNode(leak) → SNNBlock → out_proj → 残差
        _ = self.input_neuron1(self.block_norm(h))  # 触发 PLIF 动力学,更新 .v
        v_in = (1.0 - self.input_neuron1.beta) * self.input_neuron1.v  # 膜电位泄漏量
        cont_block = self.snn_block.single_step_forward(v_in)
        res_block = self.block_out_proj(cont_block)
        h = h + res_block - res_block.mean(dim=-1, keepdim=True)

        # 子层 2: SNNFFN — RMSNorm → PLIFNode(leak) → SNNFFN → out_proj → 残差
        _ = self.input_neuron2(self.ffn_norm(h))
        v_in2 = (1.0 - self.input_neuron2.beta) * self.input_neuron2.v  # 膜电位泄漏量
        cont_ffn = self.snn_ffn.single_step_forward(v_in2)
        res_ffn = self.ffn_out_proj(cont_ffn)
        h = h + res_ffn - res_ffn.mean(dim=-1, keepdim=True)

        return h, torch.tensor(0.0, device=h.device)