"""
SNNBlock: 完整的 SNN 隐状态空间 Block(并行化版本)

结构(每个 SNN 时间步):
  spike_in {0,1}^D
    ├─→ W_in     → I[t] ∈ R^{D*N}
    ├─→ W_β^(x)  + b_β → σ      → β(t)
    ├─→ W_α^(x)  + b_α → softplus → α(t)
    ├─→ W_th^(x) + b_th → |·|+V_min → V_th(t)
    ├─→ W_gate   → sigmoid → gate ∈ (0,1)^D
    └─→ W_skip   → I_skip ∈ R^D

  SelectivePLIF(I, β, α, V_th) → s[t] ∈ {0,1}^{D*N}

  W_out · V_post[t] ⊙ gate + I_skip → 连续输出 ∈ R^D

数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import base, layer, surrogate

from .selective_plif import SelectivePLIFNode
from .parallel_scan import plif_parallel_forward


# ====== Fused modulation activations (torch.compile) ======
# Fuses sigmoid + softplus + abs + alpha*I into a single kernel:
# 7-8 separate element-wise kernels → 1 fused kernel, roughly 4x faster on DN-sized tensors.
# The first call triggers JIT compilation (a few seconds); subsequent calls hit the cache.

@torch.compile(backend='inductor', fullgraph=True)
def _fused_modulation(raw_beta, b_beta, raw_alpha, b_alpha, raw_th, b_th, v_th_min, I_all):
    beta = torch.sigmoid(raw_beta + b_beta)
    alpha = F.softplus(raw_alpha + b_alpha)
    v_th = v_th_min + torch.abs(raw_th + b_th)
    u = alpha * I_all
    return beta, u, v_th
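

# A minimal eager-mode reference with identical semantics — a hypothetical helper
# (not called by SNNBlock) for environments where torch.compile is unavailable,
# or for checking the compiled kernel's output against plain eager execution:
def _fused_modulation_eager(raw_beta, b_beta, raw_alpha, b_alpha,
                            raw_th, b_th, v_th_min, I_all):
    beta = torch.sigmoid(raw_beta + b_beta)
    alpha = F.softplus(raw_alpha + b_alpha)
    v_th = v_th_min + torch.abs(raw_th + b_th)
    return beta, alpha * I_all, v_th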


class SNNBlock(base.MemoryModule):
    """
    单个 SNN Block(并行化)。

    Args:
        D: 可见维度(Block 间通信的维度)
        N: 状态扩展因子(每个通道的隐神经元数)
        v_th_min: 动态阈值下限
        surrogate_function: surrogate gradient 函数
    """

    def __init__(
        self,
        D: int,
        N: int = 8,
        v_th_min: float = 0.1,
        surrogate_function=surrogate.Sigmoid(alpha=4.0),
    ):
        super().__init__()
        self.D = D
        self.N = N
        self.v_th_min = v_th_min
        DN = D * N

        # ====== Six parallel input projections (SNN synapses; spike inputs) ======
        self.W_in = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_beta_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_alpha_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_th_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_gate = layer.Linear(D, D, bias=False, step_mode='s')
        self.W_skip = layer.Linear(D, D, bias=False, step_mode='s')
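        # Weight shapes (nn.Linear stores (out_features, in_features)):
        # W_in / W_beta_x / W_alpha_x / W_th_x: (DN, D); W_gate / W_skip: (D, D).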

        # ====== β/α/V_th depend only on spike_in (no W^(V)·V term) ======

        # ====== Modulation biases (structured initialization) ======
        self.b_beta = nn.Parameter(torch.empty(DN))
        self.b_alpha = nn.Parameter(torch.empty(DN))
        self.b_th = nn.Parameter(torch.empty(DN))

        # ====== Output projection: D*N → D (SNN synapse) ======
        self.W_out = layer.Linear(DN, D, bias=False, step_mode='s')

        # ====== Hidden state-space neurons (D*N of them, dynamically parameterized) ======
        self.hidden_neuron = SelectivePLIFNode(
            surrogate_function=surrogate_function,
            detach_reset=False,
        )

        # ====== Parameter initialization ======
        self._initialize_parameters()

    def _initialize_parameters(self):
        """功能引导初始化。"""
        D, N = self.D, self.N
        K_ref = 16

        # Target β distribution: multiple time scales spanning [0.80, 0.99]
        beta_values = torch.linspace(0.80, 0.99, N)

        # ====== 1. β biases: logit-spaced + random perturbation across channels ======
        b_beta_per_n = torch.log(beta_values / (1.0 - beta_values))
        # Use the per-n values as means, with N(0, 0.1) noise to break the symmetry across the D channels
        self.b_beta.data.copy_(b_beta_per_n.repeat(D))
        self.b_beta.data.add_(torch.empty_like(self.b_beta).normal_(0, 0.1))
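        # Sanity check on the spacing (at zero raw input, β = sigmoid(b_beta)):
        # β = 0.80 → b = ln(0.80/0.20) ≈ 1.386; β = 0.99 → b = ln(0.99/0.01) ≈ 4.595.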

        # ====== 2. α biases: softplus(0.5413) ≈ 1.0, plus random perturbation across channels ======
        # Mean 0.5413 with N(0, 0.1) noise → α roughly in [0.7, 1.3]
        self.b_alpha.data.normal_(0.5413, 0.1)

        # ====== 3. W^(x) weights ======
        for lin in [self.W_in, self.W_gate, self.W_skip, self.W_out]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
        for lin in [self.W_beta_x, self.W_alpha_x, self.W_th_x]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
            lin.weight.data.mul_(0.1)

        # ====== 4. W_in time-scale scaling ======
        scale_per_n = torch.sqrt(1.0 - beta_values ** 2)  # (N,)
        scale_DN = scale_per_n.repeat(D)  # (D*N,)
        with torch.no_grad():
            self.W_in.weight.mul_(scale_DN.unsqueeze(1))
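        # Rationale (a standard AR(1) argument): for V[t] = β·V[t-1] + I[t] with
        # i.i.d. input, the stationary variance is Var[V] = Var[I] / (1 - β²), so
        # scaling each neuron's input rows by sqrt(1 - β²) keeps the membrane
        # variance roughly constant across the N time scales.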

        # ====== 5. b_th: σ_V calibration ======
        # σ_V = sqrt(p/3) · sqrt(1 - β^{2K})
        # where p is the input firing rate. An earlier version assumed p = 0.5
        # (σ_I = 0.408), but the actual input_neuron firing rate is roughly
        # 0.07-0.45, and lower in deeper layers. Use the conservative estimate
        # p = 0.15 to avoid a v_th so high that neurons go silent.
        p_assumed = 0.15
        sigma_I_base = math.sqrt(p_assumed / 3.0)
        sigma_V_per_n = sigma_I_base * torch.sqrt(
            1.0 - beta_values ** (2 * K_ref)
        )
        target_p_fire = torch.linspace(0.25, 0.08, N)
        z_scores = math.sqrt(2.0) * torch.erfinv(
            2.0 * (1.0 - target_p_fire) - 1.0
        )
        target_V_th = sigma_V_per_n * z_scores
        b_th_per_n = torch.clamp(target_V_th - self.v_th_min, min=0.05)
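        # Worked example at the slowest scale (β = 0.99, target_p_fire = 0.08):
        # σ_V ≈ sqrt(0.15/3) · sqrt(1 - 0.99^32) ≈ 0.224 · 0.524 ≈ 0.117, z ≈ 1.405,
        # so target_V_th ≈ 0.165 and b_th ≈ 0.165 - 0.1 = 0.065 (default v_th_min = 0.1).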
        # Use the per-n values as means, with N(0, 0.02) noise to break the symmetry across the D channels
        self.b_th.data.copy_(b_th_per_n.repeat(D))
        self.b_th.data.add_(torch.empty_like(self.b_th).normal_(0, 0.02))

        # ====== 6. W_out firing-rate equalization scaling ======
        # Columns reading from low-rate neurons are upweighted by 1/sqrt(target rate)
        # (normalized to mean 1) so every hidden neuron contributes comparably.
        out_scale_per_n = 1.0 / torch.sqrt(target_p_fire)
        out_scale_per_n = out_scale_per_n / out_scale_per_n.mean()
        out_scale_DN = out_scale_per_n.repeat(D)
        with torch.no_grad():
            self.W_out.weight.mul_(out_scale_DN.unsqueeze(0))

    def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
        """
        并行前向传播:使用 parallel scan 处理全序列。

        Args:
            spike_in_seq: (TK, batch, D) — 全部 T×K 帧的输入 spike

        Returns:
            continuous_out: (TK, batch, D) — 全部 T×K 帧的连续输出(V_post 经 W_out 投影)
        """
        TK, batch, D = spike_in_seq.shape
        DN = self.D * self.N

        # ====== Phase 1: batched projections (all TK frames computed at once) ======
        flat = spike_in_seq.reshape(TK * batch, D)

        I_all = F.linear(flat, self.W_in.weight).reshape(TK, batch, DN)
        raw_beta = F.linear(flat, self.W_beta_x.weight).reshape(TK, batch, DN)
        raw_alpha = F.linear(flat, self.W_alpha_x.weight).reshape(TK, batch, DN)
        raw_th = F.linear(flat, self.W_th_x.weight).reshape(TK, batch, DN)
        gate_all = torch.sigmoid(
            F.linear(flat, self.W_gate.weight).reshape(TK, batch, D)
        )
        I_skip_all = F.linear(flat, self.W_skip.weight).reshape(TK, batch, D)

        # ====== Phase 1b: fused activations (torch.compile → single kernel) ======
        beta_all, u_hidden, v_th_all = _fused_modulation(
            raw_beta, self.b_beta, raw_alpha, self.b_alpha,
            raw_th, self.b_th, self.v_th_min, I_all,
        )

        # Fetch the hidden neurons' initial state
        v_init_hidden = self.hidden_neuron.v
        if isinstance(v_init_hidden, float):
            v_init_hidden = torch.zeros(batch, DN, device=flat.device, dtype=flat.dtype)

        s_hidden, V_post_hidden, _ = plif_parallel_forward(
            beta_all, u_hidden, v_th_all, v_init_hidden, max_iter=3,
            surrogate_function=self.hidden_neuron.surrogate_function,
        )

        # Update the hidden-neuron state (keep the last step for the next call)
        self.hidden_neuron.v = V_post_hidden[-1].detach()

        # ====== Phase 4: output projection (V_post → W_out: continuous gradients flow straight to β) ======
        # Feeding V_post (the membrane potential) instead of spikes into W_out removes
        # the surrogate-gradient bottleneck:
        #   spike path:  ∂spike/∂β = surrogate'(V - v_th) · V_prev ≈ 0 (at most time steps)
        #   V_post path: ∂V_post/∂β = V_prev (no surrogate blocking; a gradient at every step)
        v_flat = V_post_hidden.reshape(TK * batch, DN)
        I_out_all = F.linear(v_flat, self.W_out.weight).reshape(TK, batch, D)
        I_total_all = I_out_all * gate_all + I_skip_all  # (TK, batch, D)

        # output_neuron has been removed: continuous values are handled by the
        # layer-level K-frame aggregation
        return I_total_all  # (TK, batch, D), continuous values

    def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
        """
        单步前向传播(用于调试/兼容)。

        Args:
            spike_in: 二值脉冲输入, shape (batch, D), 值域 {0, 1}

        Returns:
            continuous_out: 连续输出, shape (batch, D)
        """
        # Materialize the hidden state if it is still the float placeholder,
        # mirroring the handling in forward_parallel (the original computed
        # V_prev but never wrote it back, leaving the conversion without effect).
        V_prev = self.hidden_neuron.v
        if isinstance(V_prev, float):
            V_prev = torch.zeros(
                spike_in.shape[0], self.D * self.N,
                device=spike_in.device, dtype=spike_in.dtype,
            )
            self.hidden_neuron.v = V_prev

        I_t = self.W_in(spike_in)

        # β modulation depends only on spike_in
        beta = torch.sigmoid(self.W_beta_x(spike_in) + self.b_beta)
        alpha = F.softplus(self.W_alpha_x(spike_in) + self.b_alpha)
        v_th = self.v_th_min + torch.abs(self.W_th_x(spike_in) + self.b_th)

        gate = torch.sigmoid(self.W_gate(spike_in))
        I_skip = self.W_skip(spike_in)

        # Updates self.hidden_neuron.v as a side effect; the spike output itself is
        # unused because the block reads out V_post below.
        s_hidden = self.hidden_neuron(I_t, beta, alpha, v_th)

        # Project V_post (the membrane potential), matching forward_parallel
        V_post = self.hidden_neuron.v  # membrane potential after firing + reset
        I_out = self.W_out(V_post)
        I_total = I_out * gate + I_skip

        return I_total  # continuous values
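

if __name__ == "__main__":
    # Minimal smoke test — a sketch that assumes the sibling modules
    # selective_plif and parallel_scan are importable (run from within the package).
    torch.manual_seed(0)
    block = SNNBlock(D=16, N=8)
    # 32 frames, batch of 4, ~20% firing rate, binary {0, 1} input
    spikes = (torch.rand(32, 4, 16) < 0.2).float()
    out = block.forward_parallel(spikes)
    print(out.shape)  # expected: torch.Size([32, 4, 16])

    # The single-step path should yield the same output shape frame by frame.
    step_block = SNNBlock(D=16, N=8)
    out_step = torch.stack([step_block.single_step_forward(s) for s in spikes])
    print(out_step.shape)  # expected: torch.Size([32, 4, 16])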