Brain2nd committed
Commit 46977a8 · verified · 1 parent: 60f0126

Initial release: NeuronSpark-0.9B pretrained SNN language model

LICENSE ADDED
@@ -0,0 +1,190 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to the Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by the Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding any notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

Copyright 2025 Zhengzheng Tang

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md ADDED
@@ -0,0 +1,117 @@
---
license: apache-2.0
language:
- zh
library_name: transformers
tags:
- snn
- spiking-neural-network
- text-generation
- neuromorphic
pipeline_tag: text-generation
---
# NeuronSpark-0.9B

## Introduction

**NeuronSpark-0.9B** is a **0.87-billion-parameter language model built entirely on Spiking Neural Networks (SNNs)**. Unlike conventional Transformer-based LLMs that rely on attention mechanisms, NeuronSpark replaces the entire computation backbone with biologically inspired spiking neurons, achieving language modeling through membrane potential dynamics, surrogate gradient training, and adaptive computation (PonderNet).

This is the **pretrained base model** (85,000 steps on a small subset of the Seq-Monkey corpus).

> **Note on training data**: Due to limited compute resources (a single DGX Spark), this model was trained for only **~85K steps on a small fraction of the full Seq-Monkey 10B-token corpus**. Despite the minimal training data, the model demonstrates emergent language capabilities, validating the architectural viability of pure SNN language models. We plan to continue scaling with more data and compute in future work.

For the instruction-tuned chat version, see [NeuronSpark-0.9B-Chat](https://modelscope.cn/models/Brain2nd/NeuronSpark-0.9B-Chat).

## Model Details

| Attribute | Value |
|-----------|-------|
| Parameters | 874M |
| Architecture | SNN Hidden State Space Model |
| Hidden Dimension (D) | 896 |
| Layers | 20 |
| SNN Timesteps (K) | 16 (PonderNet adaptive) |
| State Expansion (N) | 8 |
| FFN Dimension | 2688 |
| Vocabulary | 6144 (custom BPE) |
| Context Length | 512 tokens |
| Training Data | Seq-Monkey (small subset, Chinese) |
| Training Tokens | ~1.4B (of ~10B available) |
| Precision | bfloat16 |
| License | Apache 2.0 |

## Architecture Highlights

- **Pure SNN**: No attention, no standard MLP; all computation is carried out by PLIF (Parametric Leaky Integrate-and-Fire) neurons
- **Membrane Potential Leakage Activation**: PLIFNode outputs the leak current `(1-β)·V_post`, naturally emphasizing fast-responding neurons over slow-memory neurons
- **Selective State Space**: Hidden neurons with input-dependent dynamic β(t), α(t), V_th(t), analogous to selective state space models (Mamba)
- **PonderNet Adaptive K**: Each token dynamically decides how many SNN timesteps to use (1 to K), weighted by a geometric distribution
- **Triton Fused Kernels**: Custom PLIF forward/backward kernels; a single-pass sequential scan replaces the earlier 3-phase approach
- **Pre-LN Residual Stream**: Continuous residual flow with RMSNorm, matching the Qwen3/LLaMA architecture pattern

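The PLIF update at the heart of these components can be sketched in a few lines of plain Python. This is an illustrative scalar version of the dynamics (fixed β and V_th, chosen arbitrarily here); the repository implements the same update as fused Triton kernels over tensors:

```python
def plif_step(v_prev, u, beta=0.9, v_th=1.0):
    """One PLIF timestep: leaky integration, threshold, soft reset."""
    v_pre = beta * v_prev + u          # membrane integration
    s = 1.0 if v_pre >= v_th else 0.0  # Heaviside spike
    v_post = v_pre - v_th * s          # soft reset: subtract the threshold
    return s, v_post

# Drive one neuron with a constant input current over 8 timesteps
v, spikes = 0.0, []
for _ in range(8):
    s, v = plif_step(v, u=0.4)
    spikes.append(s)
print(spikes)  # → [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
```

During training, the hard Heaviside step is replaced by a surrogate gradient so that backpropagation can flow through `s`.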
## Quickstart

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Brain2nd/NeuronSpark-0.9B",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("Brain2nd/NeuronSpark-0.9B")

# Text completion
text = f"{tokenizer.bos_token}人工智能的发展"
input_ids = tokenizer(text, return_tensors="pt")["input_ids"]

output_ids = model.generate(
    input_ids,
    max_new_tokens=128,
    do_sample=True,  # required for temperature/top_k to take effect
    temperature=0.8,
    top_k=50,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
**Example Output:**
```
人工智能的发展,为人类的未来发展提供了新的机遇。在未来,人工智能将是未来人工智能发展的重要方向。
```

## Requirements

```bash
pip install torch transformers spikingjelly safetensors
# For Triton kernels (GPU): pip install triton
```

## Training

Trained on a single NVIDIA DGX Spark (GB10, 128GB unified memory) with 4-GPU DDP.
Due to compute constraints, training used only a small subset of the full corpus (~85K steps, ~1.4B tokens of the ~10B available). Even with this limited data budget, the model acquires basic language generation ability, demonstrating the architectural viability of pure SNN language modeling.

```bash
torchrun --nproc_per_node=4 train_ddp.py \
    --D 896 --D_ff 2688 --K 16 --num_layers 20 \
    --batch_size 8 --accumulation_steps 8 \
    --learning_rate 2e-4 --warmup_iters 1000
```

## Citation

```bibtex
@misc{neuronspark2025,
  title={NeuronSpark: A Spiking Neural Network Language Model with Selective State Space Dynamics},
  author={Zhengzheng Tang},
  year={2025},
  url={https://github.com/Brain2nd/NeuronSpark}
}
```

## Contact

- **Author**: Zhengzheng Tang
- **Email**: zztangbu@bu.edu
- **GitHub**: [Brain2nd/NeuronSpark](https://github.com/Brain2nd/NeuronSpark)
atomic_ops/__init__.py ADDED
@@ -0,0 +1,9 @@
from .selective_plif import SelectivePLIFNode
from .plif_node import PLIFNode
from .lateral_inhibition import LateralInhibition
from .snn_block import SNNBlock
from .snn_ffn import SNNFFN
from .snn_decoder_layer import SNNDecoderLayer
from .parallel_scan import hillis_steele_scan, linear_recurrence, plif_parallel_forward
from .fp16_codec import fp16_encode, fp16_decode
from .rms_norm import RMSNorm
atomic_ops/fp16_codec.py ADDED
@@ -0,0 +1,98 @@
"""
FP16 binary encode/decode — model-boundary operations (no trainable parameters).

IEEE 754 float16 bit layout (K=16 timesteps):
    timestep:  0    1  2  3  4  5   6  7  8  9  10 11 12 13 14 15
    bit:      sign  E4 E3 E2 E1 E0  M9 M8 M7 M6 M5 M4 M3 M2 M1 M0
    meaning:  sign  ← exponent (bias=15) →  ← mantissa (implicit 1.xxx) →

Encode: embedding → IEEE 754 float16 bit extraction → 16 binary spike frames
        (detached; fixed preprocessing)
Decode: 16 binary spike frames → IEEE 754 bit reconstruction → continuous values
        (differentiable; gradients propagate via the surrogate gradient)
"""

import torch
from torch import Tensor


def fp16_encode(emb: Tensor, K: int = 16) -> Tensor:
    """FP16 binary encoding (model-boundary operation, fixed preprocessing).

    Converts continuous embeddings into IEEE 754 float16 bit patterns used as
    the SNN's spike input.

    Args:
        emb: (batch, seq_len, D) continuous embeddings
        K: number of timesteps (must be 16, one per float16 bit)

    Returns:
        spike_seq: (seq_len*K, batch, D) binary {0, 1}, detached
    """
    batch, seq_len, D = emb.shape

    # Convert to float16 to obtain the IEEE 754 bit pattern.
    # Clamp to avoid overflow to Inf (float16 max is 65504).
    emb_fp16 = emb.float().clamp(-65504.0, 65504.0).half()
    bits_int = emb_fp16.view(torch.int16)  # (batch, seq_len, D)

    # Extract the 16 bits (MSB first: sign, exponent, mantissa)
    shifts = torch.arange(15, -1, -1, device=emb.device)  # [15, 14, ..., 0]
    # bits_int: (batch, seq_len, D) → unsqueeze → (batch, seq_len, 1, D)
    # shifts: (K,) → view → (1, 1, K, 1)
    bits = ((bits_int.unsqueeze(2) >> shifts.view(1, 1, K, 1)) & 1)  # (batch, seq_len, K, D)

    # Cast to the compute dtype and detach (encoding takes no gradient)
    bits = bits.to(emb.dtype).detach()

    # reshape → (seq_len*K, batch, D)
    return bits.reshape(batch, seq_len * K, D).permute(1, 0, 2).contiguous()


def fp16_decode(spikes: Tensor, seq_len: int, K: int = 16) -> Tensor:
    """Exact FP16 bit decoding: rebuild float16 values from 16 binary spikes.

    Exact inverse of fp16_encode. Fully differentiable — gradients flow through
    the IEEE 754 reconstruction formula to every spike output, and from there
    into the SNN via the surrogate gradient.

    IEEE 754 float16 reconstruction:
        Normal    (exp > 0): (-1)^sign * 2^(exp - 15) * (1 + mant_frac)
        Subnormal (exp = 0): (-1)^sign * 2^(-14) * mant_frac
        where mant_frac = Σ mant_bit_i * 2^{-(i+1)}, i = 0..9

    Args:
        spikes: (seq_len*K, batch, D) binary {0, 1} (output-neuron spikes)
        seq_len: token sequence length
        K: number of timesteps (= 16)

    Returns:
        decoded: (batch, seq_len, D) continuous values
    """
    batch, D = spikes.shape[1], spikes.shape[2]

    # (seq_len*K, batch, D) → (batch, seq_len, K, D)
    s = spikes.permute(1, 0, 2).reshape(batch, seq_len, K, D)

    # ---- Sign: bit 0 ----
    sign = 1.0 - 2.0 * s[:, :, 0, :]  # +1 or -1

    # ---- Exponent: bits 1-5, weighted sum → integer 0..31 ----
    exp_weights = torch.tensor(
        [16.0, 8.0, 4.0, 2.0, 1.0],
        device=spikes.device, dtype=spikes.dtype,
    )
    exp_val = (s[:, :, 1:6, :] * exp_weights.view(1, 1, 5, 1)).sum(dim=2)

    # ---- Mantissa fraction: bits 6-15, weighted sum → [0, 1) ----
    mant_weights = torch.tensor(
        [2.0 ** (-i) for i in range(1, 11)],
        device=spikes.device, dtype=spikes.dtype,
    )
    mant_frac = (s[:, :, 6:, :] * mant_weights.view(1, 1, 10, 1)).sum(dim=2)

    # ---- IEEE 754 reconstruction ----
    # Normal:    (-1)^s * 2^(exp-15) * (1 + mant_frac)
    # Subnormal: (-1)^s * 2^(-14) * mant_frac
    is_normal = (exp_val > 0)

    normal_val = sign * torch.exp2(exp_val - 15.0) * (1.0 + mant_frac)
    subnormal_val = sign * (2.0 ** -14) * mant_frac

    return torch.where(is_normal, normal_val, subnormal_val)
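As a sanity check on the bit layout documented above, the same encode/decode path can be mirrored with the standard library's binary16 support. This is a pure-Python sketch over single floats (the module above does the equivalent on tensors), using only `struct`'s `'e'` half-precision format:

```python
import struct

def f16_bits(x: float) -> list:
    """IEEE 754 binary16 bits, MSB first: sign, 5 exponent bits, 10 mantissa bits."""
    (n,) = struct.unpack('<H', struct.pack('<e', x))  # 'e' = half precision
    return [(n >> s) & 1 for s in range(15, -1, -1)]

def f16_from_bits(bits) -> float:
    """Rebuild the value via the normal/subnormal formulas from the docstring."""
    sign = -1.0 if bits[0] else 1.0
    exp = sum(b << i for i, b in enumerate(reversed(bits[1:6])))
    mant = sum(b * 2.0 ** -(i + 1) for i, b in enumerate(bits[6:]))
    if exp > 0:                                # normal: 2^(exp-15) * (1 + mant)
        return sign * 2.0 ** (exp - 15) * (1.0 + mant)
    return sign * 2.0 ** -14 * mant            # subnormal

# Round-trips exactly for any value representable in float16
for x in (1.0, -0.375, 2.0 ** -15):
    f16 = struct.unpack('<e', struct.pack('<e', x))[0]
    assert f16_from_bits(f16_bits(x)) == f16
```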
atomic_ops/lateral_inhibition.py ADDED
@@ -0,0 +1,215 @@
"""
LateralInhibition: lateral-inhibition normalization (divisive normalization).

Neuroscience background:
    Carandini & Heeger (2012), "Normalization as a canonical neural computation".
    Lateral inhibition is one of the brain's most basic computational primitives:
    excitatory activity is fed back through a pool of inhibitory interneurons,
    implementing gain control.

SNN mechanism:
    1. Excitatory population activity:  activity_i = h_i²
    2. Inhibitory interneuron pool:     pool = mean(activity) = mean(h²)
    3. Shunting (divisive) inhibition:  h_norm = h / sqrt(pool + ε)
    4. Gain modulation:                 output = gain · h_norm

Replaces RMSNorm: the math is equivalent, but in the SNN framework it has a
direct neuroscientific interpretation — RMSNorm is a special case of divisive
normalization.

Triton fused kernel:
    - Forward:  {mean(h²), rsqrt, element-wise mul} → 1 kernel launch
    - Backward: {recompute norm, grad_gain, grad_h} → 1 kernel launch
    - One block per row (D dim); rows are processed in parallel
"""

import os

import torch
import torch.nn as nn
from spikingjelly.activation_based import base


# ============================================================
# Triton fused kernels
# ============================================================

_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
    os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS

_HAS_TRITON = False
try:
    import triton
    import triton.language as tl
    _HAS_TRITON = True
except ImportError:
    pass


if _HAS_TRITON:

    @triton.jit
    def _li_fwd_kernel(
        X_ptr, GAIN_ptr, OUT_ptr,
        stride_row,
        D: tl.constexpr,
        eps: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Forward: out = x * rsqrt(mean(x²) + eps) * gain

        Grid: (num_rows,). Each program processes one row of D elements.
        Computation in float32; storage in input dtype.
        """
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_D)
        mask = cols < D
        off = row * stride_row + cols

        # Load in float32
        x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
        gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # Inhibitory pool: population activity
        variance = tl.sum(x * x, axis=0) / D
        rrms = 1.0 / tl.sqrt(variance + eps)

        # Divisive inhibition + gain modulation
        out = x * rrms * gain

        tl.store(OUT_ptr + off, out, mask=mask)

    @triton.jit
    def _li_bwd_kernel(
        DOUT_ptr, X_ptr, GAIN_ptr,
        DX_ptr, DGAIN_ptr,
        stride_row,
        D: tl.constexpr,
        eps: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Backward: grad_x, grad_gain (per-row, reduced externally).

        Grid: (num_rows,).
        d_x = rrms * (d_out * gain - x_hat * mean(d_out * gain * x_hat))
        d_gain_row = d_out * x_hat  (sum across rows done outside the kernel)
        """
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_D)
        mask = cols < D
        off = row * stride_row + cols

        dout = tl.load(DOUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
        gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # Recompute forward (avoids saving intermediate tensors)
        variance = tl.sum(x * x, axis=0) / D
        rrms = 1.0 / tl.sqrt(variance + eps)
        x_hat = x * rrms

        # grad_gain (per-row contribution)
        dgain = dout * x_hat
        tl.store(DGAIN_ptr + off, dgain, mask=mask)

        # grad_x: rrms * (dout*gain - x_hat * mean(dout*gain*x_hat))
        dout_gain = dout * gain
        dot = tl.sum(dout_gain * x_hat, axis=0) / D
        dx = (dout_gain - x_hat * dot) * rrms

        tl.store(DX_ptr + off, dx, mask=mask)


class _LateralInhibitionTriton(torch.autograd.Function):
    """Triton-accelerated lateral inhibition (divisive normalization)."""

    @staticmethod
    def forward(ctx, x, gain, eps):
        orig_shape = x.shape
        D = x.shape[-1]
        x_2d = x.reshape(-1, D).contiguous()
        N = x_2d.shape[0]

        out = torch.empty_like(x_2d)

        BLOCK_D = triton.next_power_of_2(D)
        _li_fwd_kernel[(N,)](
            x_2d, gain, out,
            x_2d.stride(0),
            D=D, eps=eps, BLOCK_D=BLOCK_D,
        )

        ctx.save_for_backward(x_2d, gain)
        ctx.eps = eps
        ctx.orig_shape = orig_shape
        ctx.N = N
        ctx.D = D

        return out.reshape(orig_shape)

    @staticmethod
    def backward(ctx, grad_output):
        x_2d, gain = ctx.saved_tensors
        D = ctx.D
        N = ctx.N

        grad_2d = grad_output.reshape(N, D).contiguous()

        dx = torch.empty_like(x_2d)
        dgain_rows = torch.empty_like(x_2d)

        BLOCK_D = triton.next_power_of_2(D)
        _li_bwd_kernel[(N,)](
            grad_2d, x_2d, gain,
            dx, dgain_rows,
            x_2d.stride(0),
            D=D, eps=ctx.eps, BLOCK_D=BLOCK_D,
        )

        # Reduce per-row dgain across all rows
        dgain = dgain_rows.sum(dim=0)

        return dx.reshape(ctx.orig_shape), dgain, None


# ============================================================
# Public module
# ============================================================

class LateralInhibition(base.MemoryModule):
    """
    Lateral-inhibition normalization layer (divisive normalization).

    Implements gain control through an inhibitory interneuron pool.

    Math:
        pool = mean(h², dim=-1)       # inhibitory pool: population activity level
        h_norm = h / sqrt(pool + ε)   # shunting inhibition
        output = gain · h_norm        # gain modulation

    Equivalent to RMSNorm, but corresponds to divisive normalization
    (Carandini & Heeger, 2012) in the SNN framework, one of the most basic
    computational primitives in neuroscience.

    CUDA: Triton fused kernel (1 launch each for forward and backward)
    CPU:  PyTorch fallback

    Args:
        dim: feature dimension (D)
        eps: numerical-stability constant
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(dim))
        self.eps = eps
        self.dim = dim

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        if _HAS_TRITON and h.is_cuda:
            return _LateralInhibitionTriton.apply(h, self.gain, self.eps)
        # PyTorch fallback
        variance = h.pow(2).mean(-1, keepdim=True)
        h_norm = h * torch.rsqrt(variance + self.eps)
        return self.gain * h_norm

    def extra_repr(self):
        return f'dim={self.dim}, eps={self.eps}'
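The closed-form gradient used in `_li_bwd_kernel` can be verified against finite differences without a GPU. This is a scalar-list sketch with the gain fixed at 1 and a single row, purely for illustration; values and tolerances are arbitrary:

```python
import math

def li_forward(x, eps=1e-6):
    """out = x * rsqrt(mean(x^2) + eps), i.e. the forward pass with gain = 1."""
    rrms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * rrms for v in x]

def li_grad_x(x, dout, eps=1e-6):
    """dx = rrms * (dout - x_hat * mean(dout * x_hat)), matching the kernel formula."""
    D = len(x)
    rrms = 1.0 / math.sqrt(sum(v * v for v in x) / D + eps)
    x_hat = [v * rrms for v in x]
    dot = sum(d - 0 for d in [0]) if False else sum(d * h for d, h in zip(dout, x_hat)) / D
    return [(d - h * dot) * rrms for d, h in zip(dout, x_hat)]

x, dout, h = [0.5, -1.0, 2.0], [1.0, 0.2, -0.3], 1e-6
analytic = li_grad_x(x, dout)
for i in range(len(x)):  # central finite differences of sum(dout * li_forward(x))
    xp, xm = list(x), list(x)
    xp[i] += h
    xm[i] -= h
    fd = sum(d * (p - m) for d, p, m in zip(dout, li_forward(xp), li_forward(xm))) / (2 * h)
    assert abs(fd - analytic[i]) < 1e-5
```

The agreement confirms the identity used by the kernel: since RMSNorm rescales by a function of the whole row, each input coordinate contributes both a direct term and a correction through the inhibitory pool.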
atomic_ops/parallel_scan.py ADDED
@@ -0,0 +1,829 @@
"""
Parallel-scan utilities: efficiently solving the SNN linear recurrence in parallel.

Three backend tiers:
    1. Fused PLIF kernel (default; CUDA + sigmoid surrogate):
       a single kernel performs the PLIF forward (scan + spike + soft reset)
       and backward (surrogate gradient)
       · per-element beta/v_th: _fused_plif_fwd_kernel / _fused_plif_bwd_kernel
       · row-param beta/v_th: _fused_plif_fwd_rowparam_kernel / _fused_plif_bwd_rowparam_kernel
    2. Triton linear_recurrence (CUDA; non-sigmoid surrogate, or none):
       column-parallel scan, O(K) work, 1 kernel launch
    3. Hillis-Steele parallel scan (CPU fallback): O(K log K) work

Linear recurrence:
    V[k] = a[k] * V[k-1] + b[k],  V[-1] = v_init

PLIF neuron dynamics:
    V_pre[k] = beta[k] * V_post[k-1] + u[k]
    s[k] = Θ(V_pre[k] - v_th[k])
    V_post[k] = V_pre[k] - v_th[k] * s[k]

See SNN_SELECTIVE_STATE_SPACE.md for the mathematical background.
"""

import os
import torch


# ============================================================
# Triton fused recurrence kernels
# ============================================================

# DGX Spark (GB10, sm_121a): the ptxas bundled with Triton 3.5.1 does not
# support sm_121a, so use the system CUDA 13.0 ptxas instead.
_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
    os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS

_HAS_TRITON = False
try:
    import triton
    import triton.language as tl
    _HAS_TRITON = True
except ImportError:
    pass

if _HAS_TRITON:

    @triton.jit
    def _fwd_recurrence_kernel(
        A_ptr, B_ptr, INIT_ptr, OUT_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Forward: V[k] = A[k]*V[k-1] + B[k], V[-1] = init.

        Grid: (ceil(num_cols / BLOCK),)
        Each program processes BLOCK columns across all K sequential steps.
        Accumulation in float32; storage in input dtype.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        for k in range(K):
            off = k * num_cols + cols
            a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
            b = tl.load(B_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v = a * v + b
            tl.store(OUT_ptr + off, v, mask=mask)

    @triton.jit
    def _bwd_recurrence_kernel(
        A_ptr, V_ptr, INIT_ptr, GRAD_OUT_ptr,
        GRAD_A_ptr, GRAD_B_ptr, GRAD_INIT_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Backward for V[k] = A[k]*V[k-1] + B[k].

        Reverse accumulation (k from K-1 down to 0):
            g = 0
            for k = K-1, ..., 0:
                g += grad_out[k]
                grad_B[k] = g
                grad_A[k] = g * V[k-1]   (V[-1] = init)
                g = g * A[k]
            grad_init = g
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        g = tl.zeros([BLOCK], dtype=tl.float32)

        for k_rev in range(K):
            k = K - 1 - k_rev
            off = k * num_cols + cols

            dV = tl.load(GRAD_OUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g = g + dV

            tl.store(GRAD_B_ptr + off, g, mask=mask)

            if k > 0:
                v_prev = tl.load(
                    V_ptr + (k - 1) * num_cols + cols,
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
            tl.store(GRAD_A_ptr + off, g * v_prev, mask=mask)

            a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g = g * a

        tl.store(GRAD_INIT_ptr + cols, g, mask=mask)


class _TritonLinearRecurrence(torch.autograd.Function):
    """Fused Triton linear recurrence: V[k] = A[k]*V[k-1] + B[k]."""

    _BLOCK = 128

    @staticmethod
    def forward(ctx, beta, u, v_init):
        beta_c = beta.contiguous()
        u_c = u.contiguous()
        v_init_c = v_init.contiguous()

        K = beta_c.shape[0]
        num_cols = beta_c[0].numel()
        V = torch.empty_like(u_c)

        BLOCK = _TritonLinearRecurrence._BLOCK
        grid = ((num_cols + BLOCK - 1) // BLOCK,)

        _fwd_recurrence_kernel[grid](
            beta_c, u_c, v_init_c, V,
            K, num_cols,
            BLOCK=BLOCK,
        )

        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            ctx.save_for_backward(beta_c, V, v_init_c)
            ctx.K = K
            ctx.num_cols = num_cols

        return V

    @staticmethod
    def backward(ctx, grad_V):
        beta, V, v_init = ctx.saved_tensors
        grad_V_c = grad_V.contiguous()

        K = ctx.K
        num_cols = ctx.num_cols

        grad_beta = torch.empty_like(beta)
        grad_u = torch.empty_like(beta)
        grad_v_init = torch.empty_like(v_init)

        BLOCK = _TritonLinearRecurrence._BLOCK
        grid = ((num_cols + BLOCK - 1) // BLOCK,)
166
+ _bwd_recurrence_kernel[grid](
167
+ beta, V, v_init, grad_V_c,
168
+ grad_beta, grad_u, grad_v_init,
169
+ K, num_cols,
170
+ BLOCK=BLOCK,
171
+ )
172
+
173
+ return grad_beta, grad_u, grad_v_init
174
+
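The reverse-accumulation formulas in the backward docstring can be sanity-checked against autograd on a naive Python loop. A minimal sketch (pure PyTorch, independent of the Triton kernel; the sizes are arbitrary):

```python
import torch

# Naive forward: V[k] = a[k] * V[k-1] + b[k], V[-1] = init
torch.manual_seed(0)
K, D = 5, 3
a = torch.rand(K, D, requires_grad=True)
b = torch.randn(K, D, requires_grad=True)
init = torch.randn(D, requires_grad=True)

v, outs = init, []
for k in range(K):
    v = a[k] * v + b[k]
    outs.append(v)
V = torch.stack(outs)

grad_out = torch.randn(K, D)
V.backward(grad_out)

# Manual reverse accumulation, mirroring the kernel docstring
with torch.no_grad():
    g = torch.zeros(D)
    gA, gB = torch.zeros(K, D), torch.zeros(K, D)
    for k in range(K - 1, -1, -1):
        g = g + grad_out[k]                              # g += grad_out[k]
        gB[k] = g                                        # grad_B[k] = g
        v_prev = V[k - 1].detach() if k > 0 else init.detach()
        gA[k] = g * v_prev                               # grad_A[k] = g * V[k-1]
        g = g * a[k].detach()                            # g = g * A[k]

ok = (torch.allclose(gA, a.grad, atol=1e-6)
      and torch.allclose(gB, b.grad, atol=1e-6)
      and torch.allclose(g, init.grad, atol=1e-6))
print(ok)
```

Each manually accumulated quantity matches the autograd gradient, confirming the closed-form backward the kernel implements.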
175
+ # ============================================================
176
+ # Fused PLIF forward/backward kernels
177
+ # ============================================================
178
+
179
+ @triton.jit
180
+ def _fused_plif_fwd_kernel(
181
+ BETA_ptr, U_ptr, VTH_ptr, INIT_ptr,
182
+ SPIKE_ptr, VPOST_ptr,
183
+ K, num_cols,
184
+ BLOCK: tl.constexpr,
185
+ ):
186
+ """Fused PLIF forward: single-pass sequential scan with inline spike + soft reset.
187
+
188
+ Exact computation — sequential scan IS the ground truth.
189
+ Replaces the 3-phase approach (linear scan + spike iteration + correction).
190
+
191
+ Per column (parallel across batch*D):
192
+ v = v_init
193
+ for k = 0..K-1:
194
+ v_pre = beta[k]*v + u[k]
195
+ spike[k] = Θ(v_pre - v_th[k])
196
+ v = v_pre - v_th[k]*spike[k]
197
+ """
198
+ pid = tl.program_id(0)
199
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
200
+ mask = cols < num_cols
201
+
202
+ v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
203
+
204
+ for k in range(K):
205
+ off = k * num_cols + cols
206
+ beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
207
+ u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
208
+ vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
209
+
210
+ v_pre = beta * v + u
211
+ spike = tl.where(v_pre >= vth, 1.0, 0.0)
212
+ v = v_pre - vth * spike # soft reset
213
+
214
+ tl.store(SPIKE_ptr + off, spike, mask=mask)
215
+ tl.store(VPOST_ptr + off, v, mask=mask)
216
+
217
+ @triton.jit
218
+ def _fused_plif_bwd_kernel(
219
+ BETA_ptr, VTH_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
220
+ GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
221
+ GRAD_BETA_ptr, GRAD_U_ptr, GRAD_VTH_ptr, GRAD_INIT_ptr,
222
+ K, num_cols, ALPHA,
223
+ BLOCK: tl.constexpr,
224
+ ):
225
+ """Fused PLIF backward: single reverse pass with Sigmoid surrogate gradient.
226
+
227
+ V_pre[k] = V_post[k] + v_th[k]*spike[k] (reconstructed)
228
+ surrogate_grad(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
229
+ where x = V_pre[k] - v_th[k] = V_post[k] - v_th[k]*(1 - spike[k])
230
+
231
+ Reverse accumulation:
232
+ acc = 0
233
+ for k = K-1 downto 0:
234
+ total_gV = grad_V_post[k] + acc
235
+ sg = surrogate_grad(V_pre[k] - v_th[k])
236
+ grad_v_pre = grad_spike[k]*sg + total_gV
237
+ grad_beta[k] = grad_v_pre * V_post[k-1]
238
+ grad_u[k] = grad_v_pre
239
+ grad_v_th[k] = -grad_spike[k]*sg - total_gV*spike[k]
240
+ acc = grad_v_pre * beta[k]
241
+ grad_v_init = acc
242
+ """
243
+ pid = tl.program_id(0)
244
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
245
+ mask = cols < num_cols
246
+
247
+ acc = tl.zeros([BLOCK], dtype=tl.float32)
248
+
249
+ for k_rev in range(K):
250
+ k = K - 1 - k_rev
251
+ off = k * num_cols + cols
252
+
253
+ beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
254
+ vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
255
+ v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
256
+ spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
257
+
258
+ g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
259
+ g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
260
+
261
+ # V_post[k-1]
262
+ if k > 0:
263
+ v_prev = tl.load(
264
+ VPOST_ptr + (k - 1) * num_cols + cols,
265
+ mask=mask, other=0.0,
266
+ ).to(tl.float32)
267
+ else:
268
+ v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
269
+
270
+ # Sigmoid surrogate gradient
271
+ x = v_post - vth * (1.0 - spike) # = V_pre - v_th
272
+ neg_ax = -ALPHA * x
273
+ neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax) # prevent exp overflow
274
+ sig = 1.0 / (1.0 + tl.exp(neg_ax))
275
+ sg = ALPHA * sig * (1.0 - sig)
276
+
277
+ total_gV = g_V + acc
278
+ grad_v_pre = g_s * sg + total_gV
279
+
280
+ tl.store(GRAD_BETA_ptr + off, grad_v_pre * v_prev, mask=mask)
281
+ tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
282
+ tl.store(GRAD_VTH_ptr + off, -g_s * sg - total_gV * spike, mask=mask)
283
+
284
+ acc = grad_v_pre * beta
285
+
286
+ tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
287
+
288
+ # ============================================================
289
+ # Fused PLIF kernels with row-parameter beta/v_th
290
+ # (constant across K steps — e.g., ParametricLIFNode scalars)
291
+ # ============================================================
292
+
293
+ @triton.jit
294
+ def _fused_plif_fwd_rowparam_kernel(
295
+ BETA_ROW_ptr, U_ptr, VTH_ROW_ptr, INIT_ptr,
296
+ SPIKE_ptr, VPOST_ptr,
297
+ K, num_cols,
298
+ BLOCK: tl.constexpr,
299
+ ):
300
+ """Fused PLIF forward with row-parameter beta and v_th.
301
+
302
+ beta and v_th are (*shape) — constant across K steps, loaded once into registers.
303
+ Reduces global memory reads from 3 per step (beta, u, v_th) to 1 (u only).
304
+ """
305
+ pid = tl.program_id(0)
306
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
307
+ mask = cols < num_cols
308
+
309
+ v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
310
+ beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
311
+ vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
312
+
313
+ for k in range(K):
314
+ off = k * num_cols + cols
315
+ u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
316
+
317
+ v_pre = beta * v + u
318
+ spike = tl.where(v_pre >= vth, 1.0, 0.0)
319
+ v = v_pre - vth * spike
320
+
321
+ tl.store(SPIKE_ptr + off, spike, mask=mask)
322
+ tl.store(VPOST_ptr + off, v, mask=mask)
323
+
324
+ @triton.jit
325
+ def _fused_plif_bwd_rowparam_kernel(
326
+ BETA_ROW_ptr, VTH_ROW_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
327
+ GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
328
+ GRAD_BETA_ROW_ptr, GRAD_U_ptr, GRAD_VTH_ROW_ptr, GRAD_INIT_ptr,
329
+ K, num_cols, ALPHA,
330
+ BLOCK: tl.constexpr,
331
+ ):
332
+ """Fused PLIF backward with row-parameter beta/v_th.
333
+
334
+ Gradients for beta and v_th are accumulated over K steps (reduction in registers).
335
+ Returns grad_beta_row (*shape) and grad_v_th_row (*shape) instead of per-step gradients.
336
+ """
337
+ pid = tl.program_id(0)
338
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
339
+ mask = cols < num_cols
340
+
341
+ beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
342
+ vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
343
+
344
+ acc = tl.zeros([BLOCK], dtype=tl.float32)
345
+ acc_grad_beta = tl.zeros([BLOCK], dtype=tl.float32)
346
+ acc_grad_vth = tl.zeros([BLOCK], dtype=tl.float32)
347
+
348
+ for k_rev in range(K):
349
+ k = K - 1 - k_rev
350
+ off = k * num_cols + cols
351
+
352
+ v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
353
+ spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
354
+
355
+ g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
356
+ g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
357
+
358
+ if k > 0:
359
+ v_prev = tl.load(
360
+ VPOST_ptr + (k - 1) * num_cols + cols,
361
+ mask=mask, other=0.0,
362
+ ).to(tl.float32)
363
+ else:
364
+ v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
365
+
366
+ # Sigmoid surrogate gradient
367
+ x = v_post - vth * (1.0 - spike)
368
+ neg_ax = -ALPHA * x
369
+ neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)
370
+ sig = 1.0 / (1.0 + tl.exp(neg_ax))
371
+ sg = ALPHA * sig * (1.0 - sig)
372
+
373
+ total_gV = g_V + acc
374
+ grad_v_pre = g_s * sg + total_gV
375
+
376
+ tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
377
+
378
+ # Accumulate gradients for row parameters (reduction over K in registers)
379
+ acc_grad_beta += grad_v_pre * v_prev
380
+ acc_grad_vth += -g_s * sg - total_gV * spike
381
+
382
+ acc = grad_v_pre * beta
383
+
384
+ tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
385
+ tl.store(GRAD_BETA_ROW_ptr + cols, acc_grad_beta, mask=mask)
386
+ tl.store(GRAD_VTH_ROW_ptr + cols, acc_grad_vth, mask=mask)
387
+
388
+ class _TritonPLIFRowParamForward(torch.autograd.Function):
389
+ """Fused Triton PLIF with row-parameter beta/v_th.
390
+
391
+ For neurons with constant beta/v_th across K steps (ParametricLIFNode).
392
+ Eliminates expand+contiguous for beta/v_th tensors, reduces memory I/O by ~40%.
393
+ """
394
+
395
+ _BLOCK = 128
396
+
397
+ @staticmethod
398
+ def forward(ctx, beta_row, u, v_th_row, v_init, alpha):
399
+ beta_row_c = beta_row.contiguous()
400
+ u_c = u.contiguous()
401
+ v_th_row_c = v_th_row.contiguous()
402
+ v_init_c = v_init.contiguous()
403
+
404
+ K = u_c.shape[0]
405
+ num_cols = u_c[0].numel()
406
+
407
+ spike = torch.empty_like(u_c)
408
+ V_post = torch.empty_like(u_c)
409
+
410
+ BLOCK = _TritonPLIFRowParamForward._BLOCK
411
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
412
+
413
+ _fused_plif_fwd_rowparam_kernel[grid](
414
+ beta_row_c, u_c, v_th_row_c, v_init_c,
415
+ spike, V_post,
416
+ K, num_cols,
417
+ BLOCK=BLOCK,
418
+ )
419
+
420
+ if any(ctx.needs_input_grad[:4]):
421
+ ctx.save_for_backward(beta_row_c, v_th_row_c, v_init_c, V_post, spike)
422
+ ctx.K = K
423
+ ctx.num_cols = num_cols
424
+ ctx.alpha = alpha
425
+
426
+ return spike, V_post
427
+
428
+ @staticmethod
429
+ def backward(ctx, grad_spike, grad_V_post):
430
+ beta_row, v_th_row, v_init, V_post, spike = ctx.saved_tensors
431
+ K = ctx.K
432
+ num_cols = ctx.num_cols
433
+ alpha = ctx.alpha
434
+
435
+ if grad_spike is None:
436
+ grad_spike = torch.zeros_like(spike)
437
+ if grad_V_post is None:
438
+ grad_V_post = torch.zeros_like(V_post)
439
+
440
+ grad_spike_c = grad_spike.contiguous()
441
+ grad_V_post_c = grad_V_post.contiguous()
442
+
443
+ grad_beta_row = torch.empty_like(beta_row)
444
+ grad_u = torch.empty_like(V_post)
445
+ grad_v_th_row = torch.empty_like(v_th_row)
446
+ grad_v_init = torch.empty_like(v_init)
447
+
448
+ BLOCK = _TritonPLIFRowParamForward._BLOCK
449
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
450
+
451
+ _fused_plif_bwd_rowparam_kernel[grid](
452
+ beta_row, v_th_row, v_init, V_post, spike,
453
+ grad_spike_c, grad_V_post_c,
454
+ grad_beta_row, grad_u, grad_v_th_row, grad_v_init,
455
+ K, num_cols, float(alpha),
456
+ BLOCK=BLOCK,
457
+ )
458
+
459
+ return grad_beta_row, grad_u, grad_v_th_row, grad_v_init, None
460
+
461
+ class _TritonPLIFForward(torch.autograd.Function):
462
+ """Fused Triton PLIF forward + backward.
463
+
464
+ Single-pass sequential scan replaces the 3-phase approach:
465
+ Phase 1 (linear scan) + Phase 2 (spike iteration) + Phase 3 (correction)
466
+ → 1 fused kernel with inline spike detection + soft reset
467
+
468
+ Advantages:
469
+ - 1 kernel launch (vs 3-4 launches + ~10 element-wise ops)
470
+ - Exact computation (no iteration convergence issues)
471
+ - Less memory (no intermediate V_L, delta_S, delta_S_prev)
472
+ - Higher precision (fp32 accumulation, no bf16 intermediate store/load)
473
+ """
474
+
475
+ _BLOCK = 128
476
+
477
+ @staticmethod
478
+ def forward(ctx, beta, u, v_th, v_init, alpha):
479
+ beta_c = beta.contiguous()
480
+ u_c = u.contiguous()
481
+ v_th_c = v_th.contiguous()
482
+ v_init_c = v_init.contiguous()
483
+
484
+ K = beta_c.shape[0]
485
+ num_cols = beta_c[0].numel()
486
+
487
+ spike = torch.empty_like(u_c)
488
+ V_post = torch.empty_like(u_c)
489
+
490
+ BLOCK = _TritonPLIFForward._BLOCK
491
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
492
+
493
+ _fused_plif_fwd_kernel[grid](
494
+ beta_c, u_c, v_th_c, v_init_c,
495
+ spike, V_post,
496
+ K, num_cols,
497
+ BLOCK=BLOCK,
498
+ )
499
+
500
+ if any(ctx.needs_input_grad[:4]):
501
+ ctx.save_for_backward(beta_c, v_th_c, v_init_c, V_post, spike)
502
+ ctx.K = K
503
+ ctx.num_cols = num_cols
504
+ ctx.alpha = alpha
505
+
506
+ return spike, V_post
507
+
508
+ @staticmethod
509
+ def backward(ctx, grad_spike, grad_V_post):
510
+ beta, v_th, v_init, V_post, spike = ctx.saved_tensors
511
+ K = ctx.K
512
+ num_cols = ctx.num_cols
513
+ alpha = ctx.alpha
514
+
515
+ if grad_spike is None:
516
+ grad_spike = torch.zeros_like(spike)
517
+ if grad_V_post is None:
518
+ grad_V_post = torch.zeros_like(V_post)
519
+
520
+ grad_spike_c = grad_spike.contiguous()
521
+ grad_V_post_c = grad_V_post.contiguous()
522
+
523
+ grad_beta = torch.empty_like(beta)
524
+ grad_u = torch.empty_like(beta)
525
+ grad_v_th = torch.empty_like(v_th)
526
+ grad_v_init = torch.empty_like(v_init)
527
+
528
+ BLOCK = _TritonPLIFForward._BLOCK
529
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
530
+
531
+ _fused_plif_bwd_kernel[grid](
532
+ beta, v_th, v_init, V_post, spike,
533
+ grad_spike_c, grad_V_post_c,
534
+ grad_beta, grad_u, grad_v_th, grad_v_init,
535
+ K, num_cols, float(alpha),
536
+ BLOCK=BLOCK,
537
+ )
538
+
539
+ return grad_beta, grad_u, grad_v_th, grad_v_init, None
540
+
541
+
542
+ # ============================================================
543
+ # Hillis-Steele parallel prefix scan (CPU fallback)
544
+ # ============================================================
545
+
546
+ def hillis_steele_scan(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
547
+ """
548
+ Hillis-Steele 并行前缀扫描:计算仿射映射序列的所有前缀复合。
549
+
550
+ 给定仿射映射 f_k(x) = a[k] * x + b[k], k = 0, ..., K-1,
551
+ 计算前缀复合 F_k = f_k ∘ f_{k-1} ∘ ... ∘ f_0,
552
+ 使得 V[k] = F_k(v_init) = A[k] * v_init + B[k]。
553
+
554
+ 复合规则: (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2)
555
+
556
+ 实现使用 torch.cat 重建张量(无原地操作),完全兼容 autograd。
557
+
558
+ Args:
559
+ a: (K, *shape) — 乘性系数(如 β)
560
+ b: (K, *shape) — 加性项(如 α·I)
561
+
562
+ Returns:
563
+ A: (K, *shape) — 前缀积 A[k] = ∏_{j=0}^{k} a[j]
564
+ B: (K, *shape) — 前缀和 B[k] 使得 V[k] = A[k] * v_init + B[k]
565
+
566
+ 并行深度: O(log K)
567
+ 工作量: O(K * log K)
568
+ """
569
+ K = a.shape[0]
570
+ A = a
571
+ B = b
572
+
573
+ d = 1
574
+ while d < K:
575
+ A_new_tail = A[d:] * A[:-d]
576
+ B_new_tail = A[d:] * B[:-d] + B[d:]
577
+
578
+ A = torch.cat([A[:d], A_new_tail], dim=0)
579
+ B = torch.cat([B[:d], B_new_tail], dim=0)
580
+
581
+ d *= 2
582
+
583
+ return A, B
584
+
585
+
586
+ # ============================================================
587
+ # Public API: linear_recurrence
588
+ # ============================================================
589
+
590
+ def linear_recurrence(beta: torch.Tensor, u: torch.Tensor, v_init: torch.Tensor) -> torch.Tensor:
591
+ """
592
+ 求解线性递推: V[k] = beta[k] * V[k-1] + u[k], V[-1] = v_init
593
+
594
+ CUDA 后端: Triton fused kernel(1 次 kernel launch,O(K) 工作量)
595
+ CPU 后端: Hillis-Steele parallel scan(O(K log K) 工作量)
596
+
597
+ Args:
598
+ beta: (K, *shape) — 衰减系数,值域 (0, 1)
599
+ u: (K, *shape) — 输入项
600
+ v_init: (*shape) — 初始状态
601
+
602
+ Returns:
603
+ V: (K, *shape) — 所有 K 步的状态
604
+ """
605
+ if _HAS_TRITON and beta.is_cuda:
606
+ return _TritonLinearRecurrence.apply(beta, u, v_init)
607
+ # CPU fallback
608
+ A, B = hillis_steele_scan(beta, u)
609
+ V = A * v_init.unsqueeze(0) + B
610
+ return V
611
+
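As a usage sketch, the CPU fallback path can be checked against a step-by-step loop. The scan is inlined from the function above so the snippet stands alone; shapes are illustrative:

```python
import torch

def hillis_steele_scan(a, b):
    # Prefix composition of affine maps f_k(x) = a[k]*x + b[k]
    # (same algorithm as the source function above)
    K = a.shape[0]
    A, B = a, b
    d = 1
    while d < K:
        A_tail = A[d:] * A[:-d]
        B_tail = A[d:] * B[:-d] + B[d:]
        A = torch.cat([A[:d], A_tail], dim=0)
        B = torch.cat([B[:d], B_tail], dim=0)
        d *= 2
    return A, B

torch.manual_seed(0)
K, D = 16, 4
beta = torch.rand(K, D) * 0.9 + 0.05   # decay in (0, 1)
u = torch.randn(K, D)
v_init = torch.randn(D)

# Sequential reference: V[k] = beta[k] * V[k-1] + u[k]
v, ref = v_init, []
for k in range(K):
    v = beta[k] * v + u[k]
    ref.append(v)
V_ref = torch.stack(ref)

A, B = hillis_steele_scan(beta, u)
V_scan = A * v_init.unsqueeze(0) + B
print(torch.allclose(V_scan, V_ref, atol=1e-5))
```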
612
+
613
+ # ============================================================
614
+ # PLIF parallel forward (with spike iteration)
615
+ # ============================================================
616
+
617
+ def plif_parallel_forward(
618
+ beta: torch.Tensor,
619
+ u: torch.Tensor,
620
+ v_th: torch.Tensor,
621
+ v_init: torch.Tensor,
622
+ max_iter: int = 3,
623
+ surrogate_function=None,
624
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
625
+ """
626
+ PLIF 神经元的并行前向传播(soft reset,surrogate gradient 兼容)。
627
+
628
+ 求解:
629
+ V_pre[k] = beta[k] * V_post[k-1] + u[k]
630
+ s[k] = Θ(V_pre[k] - v_th[k])
631
+ V_post[k] = V_pre[k] - v_th[k] * s[k]
632
+
633
+ 方法:
634
+ Phase 1: 线性轨迹 parallel scan(有梯度)
635
+ Phase 2: spike 不动点迭代(detach,确定离散 spike pattern)
636
+ Phase 3: 用 converged spike pattern 重算 V_post(有梯度),
637
+ surrogate_function(V_pre - v_th) 生成可微 spike 输出
638
+
639
+ Args:
640
+ beta: (K, *shape) — 衰减系数
641
+ u: (K, *shape) — 输入 α·I
642
+ v_th: (K, *shape) — 动态阈值
643
+ v_init: (*shape) — 初始膜电位
644
+ max_iter: spike 不动点迭代次数上限
645
+ surrogate_function: SpikingJelly surrogate gradient 函数(如 surrogate.Sigmoid(alpha=4.0))
646
+ None 时退化为硬阈值(无梯度)
647
+
648
+ Returns:
649
+ spike: (K, *shape) — spike 模式(有 surrogate gradient)
650
+ V_post: (K, *shape) — 发放后膜电位
651
+ V_pre: (K, *shape) — 发放前膜电位(fused path 返回 None)
652
+ """
653
+ # Fused Triton path: single-pass sequential scan (exact, no iteration)
654
+ # Replaces 3-phase approach with 1 kernel launch — ~3x faster forward, ~5x faster backward
655
+ if (_HAS_TRITON and beta.is_cuda and surrogate_function is not None
656
+ and hasattr(surrogate_function, 'alpha')
657
+ and type(surrogate_function).__name__ == 'Sigmoid'):
658
+ alpha = float(surrogate_function.alpha)
659
+ spike, V_post = _TritonPLIFForward.apply(beta, u, v_th, v_init, alpha)
660
+ return spike, V_post, None
661
+
662
+ # Fallback: 3-phase approach (CPU, non-Sigmoid surrogates, or no surrogate)
663
+ # Phase 1: linear trajectory V_L (assuming the neuron never fires)
664
+ V_L = linear_recurrence(beta, u, v_init) # (K, *shape)
665
+
666
+ # Phase 2: spike fixed-point iteration (fully detached, no gradient graph)
+ # Goal: determine which neurons fire at which steps (discrete decisions)
668
+ with torch.no_grad():
669
+ V_L_det = V_L.detach()
670
+ beta_det = beta.detach()
671
+ v_th_det = v_th.detach()
672
+ v_init_det = v_init.detach() if isinstance(v_init, torch.Tensor) else v_init
673
+
674
+ spike_pattern = (V_L_det >= v_th_det).float()
675
+
676
+ for _ in range(max_iter - 1):
677
+ # Compute ΔS: ΔS[k] = beta[k] * ΔS[k-1] + v_th[k] * s[k]
678
+ delta_S = linear_recurrence(
679
+ beta_det, v_th_det * spike_pattern,
680
+ torch.zeros_like(v_init_det) if isinstance(v_init_det, torch.Tensor)
681
+ else torch.zeros_like(V_L_det[0]),
682
+ )
683
+
684
+ # ΔS_prev = ΔS[k-1] (shifted by one step)
685
+ delta_S_prev = torch.zeros_like(delta_S)
686
+ delta_S_prev[1:] = delta_S[:-1]
687
+
688
+ # V_pre = V_L - beta * ΔS_prev
689
+ V_pre_det = V_L_det - beta_det * delta_S_prev
690
+
691
+ # Update the spike pattern
692
+ spike_new = (V_pre_det >= v_th_det).float()
693
+
694
+ # Convergence check
695
+ if torch.equal(spike_new, spike_pattern):
696
+ break
697
+ spike_pattern = spike_new
698
+
699
+ # Phase 3: recompute V_post with the converged spike pattern (full gradients)
+ # spike_pattern is detached and enters the computation as a constant;
+ # gradients flow through u, v_th, beta and v_init
702
+ u_eff = u - v_th * spike_pattern
703
+ V_post = linear_recurrence(beta, u_eff, v_init) # (K, *shape)
704
+
705
+ # Rebuild V_pre (with gradients, used for the surrogate gradient)
706
+ V_post_prev = torch.zeros_like(V_post)
707
+ if isinstance(v_init, torch.Tensor):
708
+ V_post_prev[0] = v_init
709
+ V_post_prev[1:] = V_post[:-1]
710
+ V_pre = beta * V_post_prev + u
711
+
712
+ # Produce the differentiable spike output
713
+ if surrogate_function is not None:
714
+ # forward: Heaviside(V_pre - v_th), backward: surrogate gradient
715
+ spike = surrogate_function(V_pre - v_th)
716
+ else:
717
+ # Degenerate mode: hard threshold, no gradient
718
+ spike = (V_pre >= v_th).float()
719
+
720
+ return spike, V_post, V_pre
721
+
722
+
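For reference, the system being solved can be written as a plain sequential loop. A minimal CPU sketch (the helper name `plif_sequential` is ours, hard threshold only, no surrogate gradient):

```python
import torch

def plif_sequential(beta, u, v_th, v_init):
    # Ground-truth sequential scan of the PLIF dynamics:
    #   V_pre[k]  = beta[k] * V_post[k-1] + u[k]
    #   s[k]      = (V_pre[k] >= v_th[k])
    #   V_post[k] = V_pre[k] - v_th[k] * s[k]   (soft reset)
    spikes, vposts = [], []
    v = v_init
    for k in range(beta.shape[0]):
        v_pre = beta[k] * v + u[k]
        s = (v_pre >= v_th[k]).float()
        v = v_pre - v_th[k] * s
        spikes.append(s)
        vposts.append(v)
    return torch.stack(spikes), torch.stack(vposts)

torch.manual_seed(0)
K, D = 8, 4
beta = torch.rand(K, D)
u = torch.randn(K, D)
v_th = 0.1 + torch.rand(K, D)
spike, V_post = plif_sequential(beta, u, v_th, torch.zeros(D))
print(spike.shape)  # torch.Size([8, 4])
```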
723
+ def plif_rowparam_forward(
724
+ beta_row: torch.Tensor,
725
+ u: torch.Tensor,
726
+ v_th_row: torch.Tensor,
727
+ v_init: torch.Tensor,
728
+ surrogate_function=None,
729
+ ) -> tuple[torch.Tensor, torch.Tensor]:
730
+ """
731
+ 行参数 PLIF 前向:beta 和 v_th 在 K 步中保持恒定。
732
+
733
+ 比 plif_parallel_forward 快 ~40%(省去 expand+contiguous,减少 2/3 显存读取)。
734
+ 用于 ParametricLIFNode(固定 beta/v_th)或合并多个固定参数神经元。
735
+
736
+ Args:
737
+ beta_row: (*shape) — 每列的衰减率(所有 K 步相同)
738
+ u: (K, *shape) — 每步输入
739
+ v_th_row: (*shape) — 每列的阈值(所有 K 步相同)
740
+ v_init: (*shape) — 初始膜电位
741
+ surrogate_function: surrogate gradient 函数
742
+
743
+ Returns:
744
+ spike: (K, *shape) — spike 模式
745
+ V_post: (K, *shape) — 发放后膜电位
746
+ """
747
+ if (_HAS_TRITON and u.is_cuda and surrogate_function is not None
748
+ and hasattr(surrogate_function, 'alpha')
749
+ and type(surrogate_function).__name__ == 'Sigmoid'):
750
+ alpha = float(surrogate_function.alpha)
751
+ spike, V_post = _TritonPLIFRowParamForward.apply(
752
+ beta_row, u, v_th_row, v_init, alpha,
753
+ )
754
+ return spike, V_post
755
+
756
+ # Fallback: expand to full (K, *shape) and use standard path
757
+ K = u.shape[0]
758
+ beta = beta_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
759
+ v_th = v_th_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
760
+ spike, V_post, _ = plif_parallel_forward(beta, u, v_th, v_init, surrogate_function=surrogate_function)
761
+ return spike, V_post
762
+
763
+
764
+ def plif_fixed_param_forward(
765
+ beta,
766
+ u: torch.Tensor,
767
+ v_th,
768
+ v_init: torch.Tensor,
769
+ max_iter: int = 3,
770
+ surrogate_function=None,
771
+ ) -> tuple[torch.Tensor, torch.Tensor]:
772
+ """
773
+ 固定参数 PLIF 神经元的并行前向(如输出神经元、FFN 神经元)。
774
+
775
+ ParametricLIFNode 方程: V[k] = beta * V[k-1] + (1-beta) * x[k]
776
+ 其中 beta = 1/(1+exp(w)), 可为 scalar tensor(保持梯度流向 w)。
777
+
778
+ scalar/0-dim beta 和 v_th 使用 row-param 内核(无需 expand 到 (K, *shape))。
779
+
780
+ Args:
781
+ beta: 衰减率 — scalar float、0-dim tensor 或 (K, *shape) tensor
782
+ u: (K, *shape) — 输入(已乘以 (1-beta))
783
+ v_th: 阈值 — scalar float、0-dim tensor 或 (K, *shape) tensor
784
+ v_init: (*shape) — 初始膜电位
785
+ max_iter: spike 迭代次数
786
+ surrogate_function: surrogate gradient 函数
787
+
788
+ Returns:
789
+ spike: (K, *shape) — spike 模式
790
+ V_post: (K, *shape) — 发放后膜电位
791
+ """
792
+ K = u.shape[0]
793
+ shape = u.shape[1:]
794
+
795
+ # Row-param fast path: beta and v_th are both scalar/0-dim → expand to (*shape) row vectors
796
+ beta_is_scalar = isinstance(beta, torch.Tensor) and beta.dim() == 0
797
+ beta_is_float = not isinstance(beta, torch.Tensor)
798
+ vth_is_scalar = isinstance(v_th, torch.Tensor) and v_th.dim() == 0
799
+ vth_is_float = not isinstance(v_th, torch.Tensor)
800
+
801
+ if (beta_is_scalar or beta_is_float) and (vth_is_scalar or vth_is_float):
802
+ # Build row vectors (*shape)
803
+ if beta_is_scalar:
804
+ beta_row = beta.expand(*shape).contiguous()
805
+ else:
806
+ beta_row = torch.full(shape, beta, device=u.device, dtype=u.dtype)
807
+ if vth_is_scalar:
808
+ v_th_row = v_th.expand(*shape).contiguous()
809
+ else:
810
+ v_th_row = torch.full(shape, v_th, device=u.device, dtype=u.dtype)
811
+ return plif_rowparam_forward(beta_row, u, v_th_row, v_init, surrogate_function)
812
+
813
+ # Full-tensor path: expand to (K, *shape) if needed
814
+ if isinstance(beta, torch.Tensor):
815
+ if beta.dim() == 0:
816
+ beta = beta.expand(K, *shape).contiguous()
817
+ else:
818
+ beta = torch.full_like(u, beta)
819
+
820
+ if isinstance(v_th, torch.Tensor):
821
+ if v_th.dim() == 0:
822
+ v_th = v_th.expand(K, *shape).contiguous()
823
+ else:
824
+ v_th = torch.full_like(u, v_th)
825
+
826
+ spike, V_post, _ = plif_parallel_forward(
827
+ beta, u, v_th, v_init, max_iter, surrogate_function,
828
+ )
829
+ return spike, V_post
atomic_ops/plif_node.py ADDED
@@ -0,0 +1,81 @@
+ """
2
+ PLIFNode: D 维固定参数 PLIF 神经元(设计文档 5.5 "普通 SNN 神经元")
3
+
4
+ 与 SelectivePLIFNode 的区别:
5
+ SelectivePLIF: β(t), α(t), V_th(t) 由输入每步动态计算(选择性记忆)
6
+ PLIFNode: β, V_th 为 D 维可学习参数,训练后固定(信号转换)
7
+
8
+ 每个维度有独立的可学习参数:
9
+ β_d = sigmoid(w_d): 时间常数(衰减率)
10
+ V_th_d: 发放阈值
11
+
12
+ 动力学(与 ParametricLIF 一致):
13
+ V[t] = β · V[t-1] + (1-β) · x[t]
14
+ s[t] = Θ(V[t] - V_th) (surrogate gradient)
15
+ V[t] -= V_th · s[t] (soft reset)
16
+ """
17
+
18
+ import math
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from spikingjelly.activation_based import base, surrogate
23
+
24
+
25
+ class PLIFNode(base.MemoryModule):
26
+ """
27
+ D 维固定参数 PLIF 神经元。
28
+
29
+ Args:
30
+ dim: 神经元数量(每个维度独立参数)
31
+ init_tau: 初始时间常数 τ(β = 1 - 1/τ)
32
+ v_threshold: 初始发放阈值
33
+ surrogate_function: surrogate gradient 函数
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ dim: int,
39
+ init_tau: float = 2.0,
40
+ v_threshold: float = 0.5,
41
+ surrogate_function=surrogate.Sigmoid(alpha=4.0),
42
+ ):
43
+ super().__init__()
44
+ # D-dimensional learnable parameters (random init, independent per dimension)
+ # w controls β = sigmoid(w); random init yields diverse time constants
+ # init_w ± 0.5 → β ∈ ~[sigmoid(init_w - 0.5), sigmoid(init_w + 0.5)]
+ # for tau = 2.0, init_w = 0 and β ∈ ~[0.38, 0.62]
48
+ init_w = -math.log(init_tau - 1.0)
49
+ self.w = nn.Parameter(torch.empty(dim).normal_(init_w, 0.5))
50
+ # v_th: firing threshold; U[0.5x, 1.5x] uniform init for per-dimension diversity
51
+ self.v_th = nn.Parameter(torch.empty(dim).uniform_(
52
+ v_threshold * 0.5, v_threshold * 1.5,
53
+ ))
54
+ self.surrogate_function = surrogate_function
55
+ # membrane potential state (reset to 0. by functional.reset_net)
56
+ self.register_memory('v', 0.)
57
+
58
+ @property
59
+ def beta(self):
60
+ """D 维衰减率 β = sigmoid(w),值域 (0, 1)。"""
61
+ return torch.sigmoid(self.w)
62
+
63
+ def forward(self, x):
64
+ """
65
+ 单步前向传播。
66
+
67
+ V[t] = β · V[t-1] + (1-β) · x[t], spike = Θ(V-V_th), soft reset。
68
+
69
+ Args:
70
+ x: 输入电流, shape (batch, dim)
71
+
72
+ Returns:
73
+ spike: 二值脉冲, shape (batch, dim), 值域 {0, 1}
74
+ """
75
+ if isinstance(self.v, float):
76
+ self.v = torch.zeros_like(x)
77
+ beta = self.beta
78
+ self.v = beta * self.v + (1.0 - beta) * x
79
+ spike = self.surrogate_function(self.v - self.v_th)
80
+ self.v = self.v - spike * self.v_th # soft reset
81
+ return spike
atomic_ops/rms_norm.py ADDED
@@ -0,0 +1,36 @@
+ """
2
+ RMSNorm: 残差流分支归一化(Pre-LN 模式)。
3
+
4
+ 位置: h → RMSNorm → PLIFNode → SNN子层 → out_proj → 残差
5
+ 作用: 控制送入 PLIFNode 的输入 scale,防止残差流漂移/爆炸。
6
+ 仅归一化分支输入,残差流本身不被归一化。
7
+
8
+ 对标 Qwen3/LLaMA 的 Pre-LN RMSNorm。
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ class RMSNorm(nn.Module):
16
+ """Root Mean Square Layer Normalization.
17
+
18
+ x_norm = x / RMS(x) * weight
19
+ RMS(x) = sqrt(mean(x^2) + eps)
20
+
21
+ Args:
22
+     dim: dimension to normalize
+     eps: numerical-stability epsilon
24
+ """
25
+
26
+ def __init__(self, dim: int, eps: float = 1e-6):
27
+ super().__init__()
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+ self.eps = eps
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ input_dtype = x.dtype
33
+ x = x.float()
34
+ variance = x.pow(2).mean(-1, keepdim=True)
35
+ x = x * torch.rsqrt(variance + self.eps)
36
+ return (self.weight * x).to(input_dtype)
atomic_ops/selective_plif.py ADDED
@@ -0,0 +1,94 @@
+ """
2
+ SelectivePLIFNode: 动态参数的 PLIF 神经元
3
+
4
+ 与标准 ParametricLIFNode 的区别:
5
+ - β(t), α(t), V_th(t) 作为外部参数每步传入,不是内部 nn.Parameter
6
+ - 本神经元无可训练参数,所有学习发生在 SNNBlock 的调制网络中
7
+ - 仅支持 step_mode='s'(单步模式)
8
+ - 仅支持 soft reset(v_reset=None)
9
+
10
+ 状态方程:
11
+ V[t] = β(t) · V[t-1] + α(t) · I[t]
12
+ s[t] = Θ(V[t] - V_th(t)) (surrogate gradient)
13
+ V[t] -= V_th(t) · s[t] (soft reset)
14
+ """
15
+
16
+ import torch
17
+ from spikingjelly.activation_based import neuron, surrogate
18
+
19
+
20
+ class SelectivePLIFNode(neuron.BaseNode):
21
+ """
22
+ 隐状态空间的核心神经元。
23
+
24
+ 接收外部动态计算的 β(t), α(t), V_th(t),执行:
25
+ charge → fire → soft reset
26
+
27
+ Args:
28
+ surrogate_function: surrogate gradient 函数,默认 Sigmoid(alpha=4.0)
29
+ detach_reset: 是否在 reset 时 detach spike,默认 False
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ surrogate_function=surrogate.Sigmoid(alpha=4.0),
35
+ detach_reset: bool = False,
36
+ ):
37
+ # v_threshold=1.0 is a placeholder; the externally supplied v_th is used instead
+ # v_reset=None selects soft-reset mode and register_memory('v', 0.)
39
+ super().__init__(
40
+ v_threshold=1.0,
41
+ v_reset=None,
42
+ surrogate_function=surrogate_function,
43
+ detach_reset=detach_reset,
44
+ step_mode='s',
45
+ backend='torch',
46
+ store_v_seq=False,
47
+ )
48
+
49
+ def single_step_forward(
50
+ self,
51
+ x: torch.Tensor,
52
+ beta: torch.Tensor,
53
+ alpha: torch.Tensor,
54
+ v_th: torch.Tensor,
55
+ ) -> torch.Tensor:
56
+ """
57
+ 单步前向传播。
58
+
59
+ Args:
60
+ x: 输入电流 I[t], shape (batch, D*N)
61
+ beta: 衰减率 β(t), shape (batch, D*N), 值域 (0, 1)
62
+ alpha: 写入增益 α(t), shape (batch, D*N), 值域 R+
63
+ v_th: 动态阈值 V_th(t), shape (batch, D*N), 值域 R+
64
+
65
+ Returns:
66
+ spike: 二值脉冲 s[t], shape (batch, D*N), 值域 {0, 1}
67
+ """
68
+ # Phase 0: on the first step, expand v from float to a tensor shaped like the input
69
+ self.v_float_to_tensor(x)
70
+
71
+ # Phase 1: Charge — membrane potential update
72
+ # V[t] = β(t) · V[t-1] + α(t) · I[t]
73
+ self.v = beta * self.v + alpha * x
74
+
75
+ # Phase 2: Fire — uses the dynamic v_th (not self.v_threshold)
+ # spike = Heaviside(V[t] - V_th(t)); backward uses the surrogate gradient
77
+ spike = self.surrogate_function(self.v - v_th)
78
+
79
+ # Phase 3: Soft Reset — V[t] -= V_th(t) · s[t]
80
+ if self.detach_reset:
81
+ spike_d = spike.detach()
82
+ else:
83
+ spike_d = spike
84
+ self.v = self.v - spike_d * v_th
85
+
86
+ return spike
87
+
88
+ def extra_repr(self) -> str:
89
+ return (
90
+ f'v_reset={self.v_reset}, '
91
+ f'detach_reset={self.detach_reset}, '
92
+ f'step_mode={self.step_mode}, '
93
+ f'surrogate={self.surrogate_function}'
94
+ )
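One charge → fire → soft-reset step from the docstring, written out in plain tensors (a sketch with a hard Heaviside in place of the surrogate; the forward values are identical):

```python
import torch
import torch.nn.functional as F

# V = beta * V_prev + alpha * I; s = (V >= v_th); V -= v_th * s
torch.manual_seed(0)
batch, DN = 2, 6
v_prev = torch.zeros(batch, DN)
I = torch.randn(batch, DN)
beta = torch.sigmoid(torch.randn(batch, DN))   # decay in (0, 1)
alpha = F.softplus(torch.randn(batch, DN))     # write gain > 0
v_th = 0.1 + torch.rand(batch, DN)             # dynamic threshold

v = beta * v_prev + alpha * I                  # charge
spike = (v >= v_th).float()                    # fire
v = v - spike * v_th                           # soft reset
print(v.shape)
```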
atomic_ops/snn_block.py ADDED
@@ -0,0 +1,242 @@
+ """
2
+ SNNBlock: 完整的 SNN 隐状态空间 Block(并行化版本)
3
+
4
+ 结构(每个 SNN 时间步):
5
+ spike_in {0,1}^D
6
+ ├─→ W_in → I[t] ∈ R^{D*N}
7
+ ├─→ W_β^(x) + b_β → σ → β(t)
8
+ ├─→ W_α^(x) + b_α → softplus → α(t)
9
+ ├─→ W_th^(x) + b_th → |·|+V_min → V_th(t)
10
+ ├─→ W_gate → sigmoid → gate ∈ (0,1)^D
11
+ └─→ W_skip → I_skip ∈ R^D
12
+
13
+ SelectivePLIF(I, β, α, V_th) → s[t] ∈ {0,1}^{D*N}
14
+
15
+ W_out · V_post[t] ⊙ gate + I_skip → 连续输出 ∈ R^D
16
+
17
+ 数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
18
+ """
19
+
20
+ import math
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from spikingjelly.activation_based import base, layer, surrogate
25
+
26
+ from .selective_plif import SelectivePLIFNode
27
+ from .parallel_scan import plif_parallel_forward
28
+
29
+
30
+ # ====== Fused modulation activations (torch.compile) ======
31
+ # Fuse sigmoid + softplus + abs + alpha*I into single kernel.
32
+ # 7-8 separate element-wise kernels → 1 fused kernel, ~4x speedup on DN-sized tensors.
33
+ # First call triggers JIT compilation (~seconds); cached for subsequent calls.
34
+
35
+ @torch.compile(backend='inductor', fullgraph=True)
36
+ def _fused_modulation(raw_beta, b_beta, raw_alpha, b_alpha, raw_th, b_th, v_th_min, I_all):
37
+ beta = torch.sigmoid(raw_beta + b_beta)
38
+ alpha = F.softplus(raw_alpha + b_alpha)
39
+ v_th = v_th_min + torch.abs(raw_th + b_th)
40
+ u = alpha * I_all
41
+ return beta, u, v_th
42
+
43
+
44
+ class SNNBlock(base.MemoryModule):
45
+ """
46
+ 单个 SNN Block(并行化)。
47
+
48
+ Args:
49
+ D: 可见维度(Block 间通信的维度)
50
+ N: 状态扩展因子(每个通道的隐神经元数)
51
+ v_th_min: 动态阈值下限
52
+ surrogate_function: surrogate gradient 函数
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ D: int,
58
+ N: int = 8,
59
+ v_th_min: float = 0.1,
60
+ surrogate_function=surrogate.Sigmoid(alpha=4.0),
61
+ ):
62
+ super().__init__()
63
+ self.D = D
64
+ self.N = N
65
+ self.v_th_min = v_th_min
66
+ DN = D * N
67
+
68
+ # ====== 六条并行输入投影(SNN 突触:spike 输入) ======
69
+ self.W_in = layer.Linear(D, DN, bias=False, step_mode='s')
70
+ self.W_beta_x = layer.Linear(D, DN, bias=False, step_mode='s')
71
+ self.W_alpha_x = layer.Linear(D, DN, bias=False, step_mode='s')
72
+ self.W_th_x = layer.Linear(D, DN, bias=False, step_mode='s')
73
+ self.W_gate = layer.Linear(D, D, bias=False, step_mode='s')
74
+ self.W_skip = layer.Linear(D, D, bias=False, step_mode='s')
75
+
76
+ # ====== β/α/V_th 仅依赖 spike_in(无 W^(V)·V 项) ======
77
+
78
+ # ====== 调制偏置(结构化初始化) ======
79
+ self.b_beta = nn.Parameter(torch.empty(DN))
80
+ self.b_alpha = nn.Parameter(torch.empty(DN))
81
+ self.b_th = nn.Parameter(torch.empty(DN))
82
+
83
+ # ====== 输出投影:D*N → D(SNN 突触) ======
84
+ self.W_out = layer.Linear(DN, D, bias=False, step_mode='s')
85
+
86
+ # ====== 隐状态空间神经元(D*N 个,动态参数) ======
87
+ self.hidden_neuron = SelectivePLIFNode(
88
+ surrogate_function=surrogate_function,
89
+ detach_reset=False,
90
+ )
91
+
92
+ # ====== 参数初始化 ======
93
+ self._initialize_parameters()
94
+
95
+ def _initialize_parameters(self):
96
+ """功能引导初始化。"""
97
+ D, N = self.D, self.N
98
+ K_ref = 16
99
+
100
+ # 目标 β 分布:多时间尺度 [0.80, 0.99]
101
+ beta_values = torch.linspace(0.80, 0.99, N)
102
+
103
+ # ====== 1. β 偏置:logit-spaced + 维度间随机扰动 ======
104
+ b_beta_per_n = torch.log(beta_values / (1.0 - beta_values))
105
+ # 以 per_n 值为均值,加 N(0, 0.1) 扰动打破 D 个通道的对称性
106
+ self.b_beta.data.copy_(b_beta_per_n.repeat(D))
107
+ self.b_beta.data.add_(torch.empty_like(self.b_beta).normal_(0, 0.1))
108
+
109
+ # ====== 2. α 偏置:softplus(0.5413) ≈ 1.0 + 维度间随机扰动 ======
110
+ # 以 0.5413 为均值,N(0, 0.1) 扰动 → α ∈ ~[0.7, 1.3]
111
+ self.b_alpha.data.normal_(0.5413, 0.1)
112
+
113
+ # ====== 3. W^(x) 权重 ======
114
+ for lin in [self.W_in, self.W_gate, self.W_skip, self.W_out]:
115
+ nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
116
+ for lin in [self.W_beta_x, self.W_alpha_x, self.W_th_x]:
117
+ nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
118
+ lin.weight.data.mul_(0.1)
119
+
120
+ # ====== 4. W_in 时间尺度缩放 ======
121
+ scale_per_n = torch.sqrt(1.0 - beta_values ** 2) # (N,)
122
+ scale_DN = scale_per_n.repeat(D) # (D*N,)
123
+ with torch.no_grad():
124
+ self.W_in.weight.mul_(scale_DN.unsqueeze(1))
125
+
126
+ # ====== 5. b_th:σ_V 校准 ======
127
+ # σ_V = sqrt(p/3) * sqrt(1 - β^{2K})
128
+ # 其中 p 是输入 firing rate。旧版假设 p=0.5(σ_I=0.408),
129
+ # 但实际 input_neuron firing rate 约 0.07~0.45,深层更低。
130
+ # 用 p=0.15 保守估计,避免 v_th 过高导致死神经元。
131
+ p_assumed = 0.15
132
+ sigma_I_base = math.sqrt(p_assumed / 3.0)
133
+ sigma_V_per_n = sigma_I_base * torch.sqrt(
134
+ 1.0 - beta_values ** (2 * K_ref)
135
+ )
136
+ target_p_fire = torch.linspace(0.25, 0.08, N)
137
+ z_scores = math.sqrt(2.0) * torch.erfinv(
138
+ 2.0 * (1.0 - target_p_fire) - 1.0
139
+ )
140
+ target_V_th = sigma_V_per_n * z_scores
141
+ b_th_per_n = torch.clamp(target_V_th - self.v_th_min, min=0.05)
142
+ # 以 per_n 值为均值,加 N(0, 0.02) 扰动打破 D 个通道的对称性
143
+ self.b_th.data.copy_(b_th_per_n.repeat(D))
144
+ self.b_th.data.add_(torch.empty_like(self.b_th).normal_(0, 0.02))
145
+
146
+ # ====== 6. W_out 发放率均衡缩放 ======
147
+ out_scale_per_n = 1.0 / torch.sqrt(target_p_fire)
148
+ out_scale_per_n = out_scale_per_n / out_scale_per_n.mean()
149
+ out_scale_DN = out_scale_per_n.repeat(D)
150
+ with torch.no_grad():
151
+ self.W_out.weight.mul_(out_scale_DN.unsqueeze(0))
152
+
153
+ def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
154
+ """
155
+ 并行前向传播:使用 parallel scan 处理全序列。
156
+
157
+ Args:
158
+ spike_in_seq: (TK, batch, D) — 全部 T×K 帧的输入 spike
159
+
160
+ Returns:
161
+ continuous_out: (TK, batch, D) — 全部 T×K 帧的连续输出(V_post 经 W_out 投影)
162
+ """
163
+ TK, batch, D = spike_in_seq.shape
164
+ DN = self.D * self.N
165
+
166
+ # ====== Phase 1: 批量投影(全部 TK 帧同时计算)======
167
+ flat = spike_in_seq.reshape(TK * batch, D)
168
+
169
+ I_all = F.linear(flat, self.W_in.weight).reshape(TK, batch, DN)
170
+ raw_beta = F.linear(flat, self.W_beta_x.weight).reshape(TK, batch, DN)
171
+ raw_alpha = F.linear(flat, self.W_alpha_x.weight).reshape(TK, batch, DN)
172
+ raw_th = F.linear(flat, self.W_th_x.weight).reshape(TK, batch, DN)
173
+ gate_all = torch.sigmoid(
174
+ F.linear(flat, self.W_gate.weight).reshape(TK, batch, D)
175
+ )
176
+ I_skip_all = F.linear(flat, self.W_skip.weight).reshape(TK, batch, D)
177
+
178
+ # ====== Phase 1b: 融合激活(torch.compile → 单 kernel)======
179
+ beta_all, u_hidden, v_th_all = _fused_modulation(
180
+ raw_beta, self.b_beta, raw_alpha, self.b_alpha,
181
+ raw_th, self.b_th, self.v_th_min, I_all,
182
+ )
183
+
184
+ # 获取隐神经元初始状态
185
+ v_init_hidden = self.hidden_neuron.v
186
+ if isinstance(v_init_hidden, float):
187
+ v_init_hidden = torch.zeros(batch, DN, device=flat.device, dtype=flat.dtype)
188
+
189
+ s_hidden, V_post_hidden, _ = plif_parallel_forward(
190
+ beta_all, u_hidden, v_th_all, v_init_hidden, max_iter=3,
191
+ surrogate_function=self.hidden_neuron.surrogate_function,
192
+ )
193
+
194
+ # 更新隐神经元状态(保存末步供下次调用)
195
+ self.hidden_neuron.v = V_post_hidden[-1].detach()
196
+
197
+ # ====== Phase 4: 输出投影(V_post → W_out: 连续梯度直通 β)======
198
+ # 用 V_post(膜电压)代替 spike 作为 W_out 输入,消除 surrogate 梯度瓶颈:
199
+ # spike 路径: ∂spike/∂β = surrogate'(V-v_th) · V_prev ≈ 0(大部分时刻)
200
+ # V_post 路径: ∂V_post/∂β = V_prev(无 surrogate 阻断,每步都有梯度)
201
+ v_flat = V_post_hidden.reshape(TK * batch, DN)
202
+ I_out_all = F.linear(v_flat, self.W_out.weight).reshape(TK, batch, D)
203
+ I_total_all = I_out_all * gate_all + I_skip_all # (TK, batch, D)
204
+
205
+ # output_neuron 已移除:连续值由层级 K 帧聚合处理
206
+ return I_total_all # (TK, batch, D), 连续值
207
+
208
+ def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
209
+ """
210
+ 单步前向传播(用于调试/兼容)。
211
+
212
+ Args:
213
+ spike_in: 二值脉冲输入, shape (batch, D), 值域 {0, 1}
214
+
215
+ Returns:
216
+ continuous_out: 连续输出, shape (batch, D)
217
+ """
218
+ V_prev = self.hidden_neuron.v
219
+ if isinstance(V_prev, float):
220
+ V_prev = torch.zeros(
221
+ spike_in.shape[0], self.D * self.N,
222
+ device=spike_in.device, dtype=spike_in.dtype,
223
+ )
224
+
225
+ I_t = self.W_in(spike_in)
226
+
227
+ # β 调制仅依赖 spike_in
228
+ beta = torch.sigmoid(self.W_beta_x(spike_in) + self.b_beta)
229
+ alpha = F.softplus(self.W_alpha_x(spike_in) + self.b_alpha)
230
+ v_th = self.v_th_min + torch.abs(self.W_th_x(spike_in) + self.b_th)
231
+
232
+ gate = torch.sigmoid(self.W_gate(spike_in))
233
+ I_skip = self.W_skip(spike_in)
234
+
235
+ s_hidden = self.hidden_neuron(I_t, beta, alpha, v_th)
236
+
237
+ # 用 V_post(膜电压)做输出投影,与 forward_parallel 一致
238
+ V_post = self.hidden_neuron.v # 发放+重置后的膜电位
239
+ I_out = self.W_out(V_post)
240
+ I_total = I_out * gate + I_skip
241
+
242
+ return I_total # 连续值
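The logit-spaced β initialization in `_initialize_parameters` relies on `b_beta = log(β/(1-β))` being the exact inverse of the sigmoid applied in `_fused_modulation`, so with zero input current (`raw_beta = 0`) each neuron recovers its target decay factor. A pure-Python sketch of that round trip (the helper names here are illustrative, not from the repo):

```python
import math

def logit(p):
    # Inverse of the sigmoid: log(p / (1 - p))
    return math.log(p / (1.0 - p))

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Target multi-timescale decay factors, analogous to linspace(0.80, 0.99, N) with N = 4.
betas = [0.80, 0.90, 0.95, 0.99]
b_beta = [logit(b) for b in betas]

# With zero spike input (raw_beta = 0), sigmoid(b_beta) reproduces the targets exactly.
recovered = [sigmoid(b) for b in b_beta]
```

This is why the bias alone pins each hidden neuron's time scale at init, while the down-scaled `W_beta_x` (×0.1) only perturbs it input-dependently.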
atomic_ops/snn_decoder_layer.py ADDED
@@ -0,0 +1,327 @@
+"""
+SNNDecoderLayer: a single SNN decoder layer (Pre-LN continuous residual stream
++ dynamic-K frame aggregation).
+
+    RMSNorm → PLIF → SNNBlock → dynamic-K aggregation → out_proj → residual
+    RMSNorm → PLIF → SNNFFN → dynamic-K aggregation → out_proj → residual
+
+Dynamic K:
+- K is the maximum step count (K_max), not a fixed count; effective steps per token ∈ [1, K_max].
+- From each token's K frames of SNN output, learn an adaptive halting probability p_halt.
+- PonderNet geometric weighting: λ_k = p_k · ∏_{j<k}(1-p_j); aggregate with the normalized weights.
+- Effective step counts differ per token: easy tokens halt early (small E[K]), hard tokens use the full budget.
+- ponder_cost regularization encourages finishing easy tokens in fewer steps.
+
+Derivation:
+    halting prob:  p_k = σ(halt_proj(frame_k)) ∈ (0,1)
+    survival prob: S_k = ∏_{j=1}^{k-1} (1 - p_j) — still running at step k
+    weight:        λ_k = p_k · S_k — probability of halting exactly at step k
+    normalization: λ̂_k = λ_k / Σ_k λ_k — weights sum to 1
+    aggregation:   output = Σ_k λ̂_k · frame_k
+    cost:          E[K] = Σ_k k · λ̂_k — expected number of steps
+
+Choosing K_max:
+    A larger K_max gives the model more capacity on hard tokens (more steps available),
+    but compute and memory grow linearly. K_max=32 lets a token use 1~32 steps.
+    PonderNet's ponder_cost regularizer keeps easy tokens from wasting steps.
+
+Inter-layer K-frame aggregation:
+- An SNN sublayer outputs K continuous frames (V_post after projection); PonderNet
+  weighting aggregates them to 1 per token.
+- The aggregate goes through out_proj, then is broadcast back to K frames for the residual.
+- This lets β's temporal dynamics propagate gradients effectively through the K-frame aggregation.
+
+Mirrors Qwen3DecoderLayer (fully equivalent in Pre-LN form):
+    Qwen3: RMSNorm → Attention → residual → RMSNorm → MLP → residual
+    SNN:   RMSNorm → PLIF → SNNBlock → dynamic-K agg → out_proj → residual
+           → RMSNorm → PLIF → SNNFFN → dynamic-K agg → out_proj → residual
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from spikingjelly.activation_based import base, surrogate
+
+from .plif_node import PLIFNode
+from .rms_norm import RMSNorm
+from .snn_block import SNNBlock
+from .snn_ffn import SNNFFN
+from .parallel_scan import plif_rowparam_forward
+
+
+# ====== Fused halt weight computation (torch.compile) ======
+# 7-8 separate element-wise kernels → one fused kernel:
+# sigmoid + clamp + log1p + cumsum + exp + normalize.
+# The first call triggers JIT compilation (~seconds); subsequent calls hit the cache.
+
+@torch.compile(backend='inductor', fullgraph=True)
+def _fused_geometric_halt(halt_logits):
+    """Fused computation of PonderNet geometric halting weights.
+
+    Input:  halt_logits (seq_len, K, batch) — raw outputs of halt_proj
+    Output: halt_weights (seq_len, K, batch) — normalized geometric weights, sum=1
+
+    Math: p_k = σ(logit_k), S_k = ∏_{j<k}(1-p_j), λ_k = p_k·S_k, λ̂_k = λ_k/Σλ
+    """
+    p_halt = torch.sigmoid(halt_logits).clamp(min=1e-7, max=1.0 - 1e-7)
+    log_1_minus_p = torch.log1p(-p_halt)  # (seq_len, K, batch)
+    # Exclusive cumsum: log_survive[:, k, :] = Σ_{j<k} log(1-p_j)
+    # Avoid torch.cat: fill [:, 1:] with the cumsum of [:, :-1]
+    log_survive = torch.zeros_like(log_1_minus_p)
+    log_survive[:, 1:, :] = torch.cumsum(log_1_minus_p[:, :-1, :], dim=1)
+    survive = torch.exp(log_survive)  # (seq_len, K, batch)
+    halt_weights = p_halt * survive  # λ_k = p_k · S_k
+    halt_weights = halt_weights / (halt_weights.sum(dim=1, keepdim=True) + 1e-8)
+    return halt_weights
+
+
+class SNNDecoderLayer(base.MemoryModule):
+    """
+    A single SNN decoder layer (continuous residual stream + K-frame aggregation).
+
+    Continuous values h (TK, batch, D) flow between layers, are converted to spikes
+    by PLIF neurons, processed by the SNN sublayers, aggregated from K frames to 1
+    per token, projected through out_proj, and broadcast back to K frames for the
+    residual connection.
+
+    K-frame aggregation gives β's temporal dynamics (which govern the membrane
+    evolution over the K steps) a differentiable token-level effect, fixing the
+    problem of β gradients being pure noise.
+
+    Args:
+        D: visible dimension
+        N: state expansion factor
+        D_ff: FFN intermediate dimension
+        v_th_min: lower bound of SNNBlock's dynamic threshold
+        ffn_v_threshold: threshold of the SNNFFN gate/up neurons
+        K: SNN time steps per token
+        num_layers: total layer count (for residual output scaling + SNNFFN down_proj scaling)
+        layer_idx: index of this layer
+    """
+
+    def __init__(
+        self,
+        D: int,
+        N: int,
+        D_ff: int,
+        v_th_min: float,
+        ffn_v_threshold: float,
+        K: int = 16,
+        num_layers: int = 1,
+        layer_idx: int = 0,
+    ):
+        super().__init__()
+        self.D = D
+        self.K = K
+
+        self.snn_block = SNNBlock(
+            D=D, N=N, v_th_min=v_th_min,
+        )
+        self.snn_ffn = SNNFFN(
+            D=D, D_ff=D_ff,
+            output_v_threshold=ffn_v_threshold,
+            num_layers=num_layers,
+            layer_idx=layer_idx,
+        )
+
+        # Pre-LN branch normalization: h → RMSNorm → PLIFNode
+        self.block_norm = RMSNorm(D)
+        self.ffn_norm = RMSNorm(D)
+
+        # Input neurons: RMSNorm(h) → V_post membrane activation (learnable per-dim β and V_th)
+        self.input_neuron1 = PLIFNode(
+            dim=D,
+            init_tau=2.0,
+            v_threshold=0.5,
+            surrogate_function=surrogate.Sigmoid(alpha=4.0),
+        )
+        self.input_neuron2 = PLIFNode(
+            dim=D,
+            init_tau=2.0,
+            v_threshold=0.5,
+            surrogate_function=surrogate.Sigmoid(alpha=4.0),
+        )
+
+        # Output projections (synapses): spikes (D) → continuous space (D)
+        self.block_out_proj = nn.Linear(D, D, bias=False)
+        self.ffn_out_proj = nn.Linear(D, D, bias=False)
+
+        # ====== Dynamic K: halting projections (synapse: SNN output → halting probability) ======
+        # halt_proj: D → 1, one halting logit per step per token
+        # PonderNet geometric weighting replaces uniform-mean aggregation
+        self.block_halt = nn.Linear(D, 1, bias=True)
+        self.ffn_halt = nn.Linear(D, 1, bias=True)
+
+        # Residual output scaling (GPT-2 style: σ = 0.02 / √(2·num_layers))
+        std = 0.02 / math.sqrt(2 * num_layers)
+        nn.init.normal_(self.block_out_proj.weight, std=std)
+        nn.init.normal_(self.ffn_out_proj.weight, std=std)
+
+        # halt init: small weights + negative bias → p_halt ≈ 0.03 → near-uniform aggregation
+        # σ(-3.5) ≈ 0.029; after geometric normalization λ_1/λ_K ≈ 1.5, close to uniform
+        for halt in [self.block_halt, self.ffn_halt]:
+            nn.init.xavier_uniform_(halt.weight)
+            halt.weight.data.mul_(0.01)
+            nn.init.constant_(halt.bias, -3.5)
+
+    def _input_neuron_parallel(self, input_neuron, x):
+        """
+        Parallel-scan forward pass for an input PLIF neuron.
+
+        Full PLIF dynamics: V[t] = β·V[t-1] + (1-β)·x[t], spike = Θ(V-V_th), soft reset.
+        The membrane leak (1-β)·V_post is returned as the activation — the amount lost
+        to exponential decay at each step. Compared with passing V_post directly, the
+        leak naturally emphasizes fast-responding neurons (large 1-β) and suppresses
+        slow-memory neurons (small 1-β), an implicit time-scale weighting.
+
+        Args:
+            input_neuron: a PLIFNode instance (learnable per-dim β and V_th)
+            x: (TK, batch, D) — continuous-valued input
+
+        Returns:
+            leak: (TK, batch, D) — membrane leak (1-β)·V_post
+        """
+        TK, batch, D = x.shape
+
+        beta = input_neuron.beta  # (D,)
+        u = (1.0 - beta) * x  # (D,) broadcast → (TK, batch, D)
+
+        v_init = input_neuron.v
+        if isinstance(v_init, float):
+            v_init = torch.zeros(batch, D, device=x.device, dtype=x.dtype)
+
+        beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
+        v_th_row = input_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()
+
+        spike, V_post = plif_rowparam_forward(
+            beta_row, u, v_th_row, v_init,
+            surrogate_function=input_neuron.surrogate_function,
+        )
+
+        input_neuron.v = V_post[-1].detach()
+        return (1.0 - beta) * V_post  # membrane leak
+
+    def _adaptive_aggregate(self, frames, halt_proj):
+        """
+        PonderNet-style adaptive K-frame aggregation (the core of dynamic K,
+        fused via torch.compile).
+
+        Compute a halting probability p_k at each step and aggregate with
+        geometric-distribution weights, so different tokens get different
+        effective step counts.
+
+        Optimization: _fused_geometric_halt fuses sigmoid+log1p+cumsum+exp+normalize
+        into a single inductor kernel (same pattern as snn_block._fused_modulation).
+
+        Math:
+            p_k = σ(halt_proj(frame_k)) — halting probability
+            S_k = ∏_{j<k} (1-p_j) — survival probability
+            λ_k = p_k · S_k — geometric weight
+            λ̂_k = λ_k / Σ λ_k — normalized
+            output = Σ λ̂_k · frame_k — weighted aggregation
+            E[K] = Σ k · λ̂_k — expected steps (ponder cost)
+
+        Args:
+            frames: (seq_len, K, batch, D) — K frames of SNN sublayer output
+            halt_proj: nn.Linear(D, 1) — halting projection (synapse)
+
+        Returns:
+            aggregated: (seq_len, batch, D) — weighted aggregate
+            ponder_cost: scalar — mean expected steps (for regularization)
+            expected_k: (seq_len, batch) — detached per-token E[K] (diagnostics)
+        """
+        seq_len, K, batch, D = frames.shape
+
+        # ====== 1. halt_proj matmul (cuBLAS) + fused geometric weights (inductor) ======
+        halt_logits = halt_proj(frames).squeeze(-1)  # (seq_len, K, batch)
+        halt_weights = _fused_geometric_halt(halt_logits)  # (seq_len, K, batch), normalized
+
+        # ====== 2. Weighted aggregation ======
+        # (seq_len, K, batch, 1) × (seq_len, K, batch, D) → sum → (seq_len, batch, D)
+        aggregated = (frames * halt_weights.unsqueeze(-1)).sum(dim=1)
+
+        # ====== 3. Ponder cost: E[K] per token ======
+        steps = torch.arange(1, K + 1, device=frames.device, dtype=frames.dtype)
+        expected_k = (halt_weights * steps[None, :, None]).sum(dim=1)  # (seq_len, batch)
+        ponder_cost = expected_k.mean()  # scalar
+
+        return aggregated, ponder_cost, expected_k.detach()
+
+    def forward_parallel(self, h):
+        """
+        Parallel forward pass: continuous residual stream + dynamic-K frame aggregation.
+
+        The SNN sublayers run over the TK dimension (K-step temporal dynamics); their
+        outputs are adaptively aggregated over K frames with PonderNet (effective step
+        counts differ per token), projected through out_proj, and broadcast back to TK
+        for the residual.
+
+        Args:
+            h: (TK, batch, D) — continuous-valued input
+
+        Returns:
+            h: (TK, batch, D) — continuous-valued output
+            ponder_cost: scalar — mean expected steps over both sublayers (for regularization)
+        """
+        TK, batch, D = h.shape
+        K = self.K
+        seq_len = TK // K
+
+        # Sublayer 1: SNNBlock — RMSNorm → PLIFNode(V_post) → SNNBlock → dynamic-K agg → out_proj → residual
+        v_in = self._input_neuron_parallel(self.input_neuron1, self.block_norm(h))
+        cont_block = self.snn_block.forward_parallel(v_in)  # (TK, batch, D), continuous
+
+        # Dynamic K-frame aggregation (PonderNet): (TK, batch, D) → (seq_len, K, batch, D) → weight → (seq_len, batch, D)
+        frames_block = cont_block.view(seq_len, K, batch, D)
+        combined_block, pc_block, ek_block = self._adaptive_aggregate(frames_block, self.block_halt)
+        res_block = self.block_out_proj(combined_block)  # (seq_len, batch, D)
+        res_block = res_block - res_block.mean(dim=-1, keepdim=True)  # center the residual
+
+        # Broadcast back to TK: each token's residual is copied K times
+        h = h + res_block.repeat_interleave(K, dim=0)
+
+        # Sublayer 2: SNNFFN — RMSNorm → PLIFNode(V_post) → SNNFFN → dynamic-K agg → out_proj → residual
+        v_in2 = self._input_neuron_parallel(self.input_neuron2, self.ffn_norm(h))
+        cont_ffn = self.snn_ffn.forward_parallel(v_in2)  # (TK, batch, D), continuous
+
+        frames_ffn = cont_ffn.view(seq_len, K, batch, D)
+        combined_ffn, pc_ffn, ek_ffn = self._adaptive_aggregate(frames_ffn, self.ffn_halt)
+        res_ffn = self.ffn_out_proj(combined_ffn)
+        res_ffn = res_ffn - res_ffn.mean(dim=-1, keepdim=True)
+
+        h = h + res_ffn.repeat_interleave(K, dim=0)
+
+        ponder_cost = (pc_block + pc_ffn) / 2.0  # average over the two sublayers
+
+        # Record the per-token E[K] range (diagnostics only; outside the graph)
+        # ek_block/ek_ffn: (seq_len, batch), detached
+        with torch.no_grad():
+            all_ek = torch.cat([ek_block.flatten(), ek_ffn.flatten()])
+            self._ek_min = all_ek.min().item()
+            self._ek_max = all_ek.max().item()
+
+        return h, ponder_cost
+
+    def single_step_forward(self, h):
+        """
+        Single-step forward pass: continuous residual stream.
+
+        Note: single-step mode cannot do dynamic-K aggregation (each step is
+        processed independently). Both training and inference use forward_parallel
+        (with dynamic-K aggregation); this method is for debugging only.
+
+        Args:
+            h: (batch, D) — continuous-valued input
+
+        Returns:
+            h: (batch, D) — continuous-valued output
+            ponder_cost: scalar — 0.0 (no ponder cost in single-step mode)
+        """
+        # Sublayer 1: SNNBlock — RMSNorm → PLIFNode(leak) → SNNBlock → out_proj → residual
+        _ = self.input_neuron1(self.block_norm(h))  # run the PLIF dynamics, updating .v
+        v_in = (1.0 - self.input_neuron1.beta) * self.input_neuron1.v  # membrane leak
+        cont_block = self.snn_block.single_step_forward(v_in)
+        res_block = self.block_out_proj(cont_block)
+        h = h + res_block - res_block.mean(dim=-1, keepdim=True)
+
+        # Sublayer 2: SNNFFN — RMSNorm → PLIFNode(leak) → SNNFFN → out_proj → residual
+        _ = self.input_neuron2(self.ffn_norm(h))
+        v_in2 = (1.0 - self.input_neuron2.beta) * self.input_neuron2.v  # membrane leak
+        cont_ffn = self.snn_ffn.single_step_forward(v_in2)
+        res_ffn = self.ffn_out_proj(cont_ffn)
+        h = h + res_ffn - res_ffn.mean(dim=-1, keepdim=True)
+
+        return h, torch.tensor(0.0, device=h.device)
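The PonderNet weighting used by the decoder layer can be checked with a small pure-Python sketch (the helper name `geometric_halt_weights` is illustrative; the repo's fused version operates on tensors):

```python
import math

def geometric_halt_weights(p_halt):
    """PonderNet-style halting weights from per-step stop probabilities.

    S_k   = prod_{j<k} (1 - p_j)  — probability of surviving to step k
    lam_k = p_k * S_k             — probability of halting exactly at step k
    Weights are renormalized to sum to 1 over the K available steps.
    """
    survive = 1.0
    lam = []
    for p in p_halt:
        lam.append(p * survive)
        survive *= (1.0 - p)
    total = sum(lam)
    return [l / total for l in lam]

# With the bias init above, every step starts at p = sigmoid(-3.5) ≈ 0.029,
# so the normalized weights are close to uniform over the K frames.
p0 = 1.0 / (1.0 + math.exp(3.5))
w = geometric_halt_weights([p0] * 4)
```

With a constant p the unnormalized weights decay geometrically, which is why a small initial p gives near-uniform aggregation while a learned large p at some step concentrates the weight there (early halting).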
atomic_ops/snn_ffn.py ADDED
@@ -0,0 +1,185 @@
+"""
+SNNFFN: the SNN equivalent of a Feed-Forward Network.
+
+Mirrors Qwen3MLP's SwiGLU structure:
+    Qwen3 MLP: down_proj( SiLU(gate_proj(x)) * up_proj(x) )
+    SNN FFN:   down_proj( gate_V_post * up_V_post ) + skip
+
+Membrane-potential gating (analogous to SiLU gating):
+    The gate/up neurons run full PLIF dynamics (integrate + threshold + reset)
+    and output membrane potentials V_post for a continuous multiplicative gate,
+    replacing a binary AND gate.
+
+Signal flow:
+    x → gate_proj → gate_neuron → V_post_gate
+    x → up_proj → up_neuron → V_post_up
+    V_post_gate × V_post_up → gated
+    down_proj(gated) + skip_proj(x) → continuous output
+"""
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from spikingjelly.activation_based import base, layer, surrogate
+
+from .plif_node import PLIFNode
+from .parallel_scan import plif_rowparam_forward
+
+
+class SNNFFN(base.MemoryModule):
+    """
+    The SNN equivalent of a Feed-Forward Network.
+
+    Args:
+        D: visible dimension (input/output spike dimension)
+        D_ff: intermediate dimension (mirrors Qwen3 intermediate_size)
+        output_v_threshold: output neuron threshold
+        num_layers: total layer count, used for down_proj scaling
+        layer_idx: index of this layer
+        surrogate_function: surrogate gradient function
+    """
+
+    def __init__(
+        self,
+        D: int,
+        D_ff: int,
+        output_v_threshold: float = 0.3,
+        num_layers: int = 1,
+        layer_idx: int = 0,
+        surrogate_function=surrogate.Sigmoid(alpha=4.0),
+    ):
+        super().__init__()
+        self.D = D
+        self.D_ff = D_ff
+
+        # ====== Three projection paths (mirroring SwiGLU: gate_proj, up_proj, down_proj) ======
+        self.gate_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
+        self.up_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
+        self.down_proj = layer.Linear(D_ff, D, bias=False, step_mode='s')
+
+        # ====== Residual path ======
+        self.skip_proj = layer.Linear(D, D, bias=False, step_mode='s')
+
+        # ====== Neurons (learnable β and V_th, D_ff-dimensional) ======
+        # gate_neuron: gating spikes
+        self.gate_neuron = PLIFNode(
+            dim=D_ff,
+            init_tau=2.0,
+            v_threshold=output_v_threshold,
+            surrogate_function=surrogate_function,
+        )
+        # up_neuron: value spikes
+        self.up_neuron = PLIFNode(
+            dim=D_ff,
+            init_tau=2.0,
+            v_threshold=output_v_threshold,
+            surrogate_function=surrogate_function,
+        )
+        # ====== Parameter initialization ======
+        self._initialize_parameters(num_layers)
+
+    def _initialize_parameters(self, num_layers: int):
+        """Initialize projection weights.
+
+        - gate_proj, up_proj, skip_proj: Kaiming uniform
+        - down_proj: Kaiming uniform × 1/√(num_layers), guarding against gradient
+          blow-up in deep stacks
+        """
+        for lin in [self.gate_proj, self.up_proj, self.skip_proj]:
+            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
+
+        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
+        self.down_proj.weight.data.mul_(1.0 / math.sqrt(num_layers))
+
+    def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
+        """
+        Parallel forward pass: process the whole sequence with a parallel scan.
+
+        Optimizations:
+        - gate_proj + up_proj merged into a single matmul (2 launches → 1)
+        - gate + up PLIF scan: row-param kernel (no expand+contiguous for beta/v_th)
+        - u_merged: vector scaling instead of cat (1 broadcast multiply replaces 2 scales + 1 cat)
+
+        Args:
+            spike_in_seq: (TK, batch, D) — input spikes for all T×K frames
+
+        Returns:
+            continuous_out: (TK, batch, D) — continuous outputs for all T×K frames
+        """
+        TK, batch, D = spike_in_seq.shape
+        D_ff = self.D_ff
+        flat = spike_in_seq.reshape(TK * batch, D)
+
+        # ====== Phase 1: batched projections (gate+up merged into one matmul) ======
+        W_gate_up = torch.cat([self.gate_proj.weight, self.up_proj.weight], dim=0)
+        I_gate_up = F.linear(flat, W_gate_up).reshape(TK, batch, 2 * D_ff)
+        I_skip = F.linear(flat, self.skip_proj.weight).reshape(TK, batch, D)
+
+        # ====== Phase 2: merged gate+up PLIF scan (row-param kernel) ======
+        beta_gate = self.gate_neuron.beta  # (D_ff,)
+        beta_up = self.up_neuron.beta  # (D_ff,)
+        surr = self.gate_neuron.surrogate_function
+
+        # u_merged: vector scaling (the D_ff-dim βs are simply concatenated; no expand needed)
+        scale_row = torch.cat([1.0 - beta_gate, 1.0 - beta_up])  # (2*D_ff,)
+        u_merged = I_gate_up * scale_row  # (TK, batch, 2*D_ff), broadcast
+
+        # beta_row / v_th_row: (batch, 2*D_ff) — learnable D_ff-dim parameters
+        beta_row = torch.cat([beta_gate, beta_up])  # (2*D_ff,)
+        beta_row = beta_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()
+
+        v_th_row = torch.cat([self.gate_neuron.v_th, self.up_neuron.v_th])  # (2*D_ff,)
+        v_th_row = v_th_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()
+
+        # v_init_merged: (batch, 2*D_ff)
+        v_init_gate = self.gate_neuron.v
+        if isinstance(v_init_gate, float):
+            v_init_gate = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
+        v_init_up = self.up_neuron.v
+        if isinstance(v_init_up, float):
+            v_init_up = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
+        v_init_merged = torch.cat([v_init_gate, v_init_up], dim=-1)
+
+        # Row-param PLIF scan: beta/v_th are read from registers, costing no memory bandwidth
+        spike_merged, V_post_merged = plif_rowparam_forward(
+            beta_row, u_merged, v_th_row, v_init_merged,
+            surrogate_function=surr,
+        )
+
+        # Membrane leak as the activation: leak = (1-β) · V_post
+        gate_leak = V_post_merged[:, :, :D_ff] * (1.0 - beta_gate)  # (TK, batch, D_ff)
+        up_leak = V_post_merged[:, :, D_ff:] * (1.0 - beta_up)  # (TK, batch, D_ff)
+        self.gate_neuron.v = V_post_merged[-1, :, :D_ff].detach()
+        self.up_neuron.v = V_post_merged[-1, :, D_ff:].detach()
+
+        # ====== Phase 3: continuous gating (leak × leak, mirroring SwiGLU) + down-projection ======
+        gated = gate_leak * up_leak  # (TK, batch, D_ff)
+        gated_flat = gated.reshape(TK * batch, D_ff)
+        I_out = F.linear(gated_flat, self.down_proj.weight).reshape(TK, batch, D) + I_skip
+
+        # output_neuron has been removed: continuous values are handled by layer-level K-frame aggregation
+        return I_out  # (TK, batch, D), continuous values
+
+    def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
+        """
+        Single-step forward pass.
+
+        Args:
+            spike_in: binary spike input, shape (batch, D), values in {0, 1}
+
+        Returns:
+            continuous_out: continuous output, shape (batch, D)
+        """
+        # Gate path — membrane-leak activation
+        _ = self.gate_neuron(self.gate_proj(spike_in))
+        gate_leak = (1.0 - self.gate_neuron.beta) * self.gate_neuron.v  # leak
+
+        # Value path — membrane-leak activation
+        _ = self.up_neuron(self.up_proj(spike_in))
+        up_leak = (1.0 - self.up_neuron.beta) * self.up_neuron.v  # leak
+
+        # Continuous gating (mirroring SwiGLU)
+        gated = gate_leak * up_leak
+
+        # Down-projection + residual
+        I_out = self.down_proj(gated) + self.skip_proj(spike_in)  # R^D
+        return I_out  # continuous values
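The membrane-leak gating above can be illustrated with a scalar sketch of one gate/up step (a hypothetical re-implementation; `plif_leak` and the numeric values are illustrative, not from the repo):

```python
def plif_leak(v, x, beta, v_th):
    """One PLIF step returning the leak (1-beta)*V_post used as the activation.

    V[t] = beta*V[t-1] + (1-beta)*x; spike if V >= V_th; soft reset V -= V_th.
    """
    v = beta * v + (1.0 - beta) * x
    spike = 1.0 if v >= v_th else 0.0
    v = v - spike * v_th
    return (1.0 - beta) * v, v

# Gate and up paths are multiplied, mirroring SiLU(gate_proj(x)) * up_proj(x) in SwiGLU.
# Here the gate path crosses threshold (soft reset shrinks its leak), the up path does not.
gate_act, _ = plif_leak(v=0.0, x=0.8, beta=0.5, v_th=0.35)
up_act, _ = plif_leak(v=0.0, x=0.6, beta=0.5, v_th=0.35)
gated = gate_act * up_act
```

The product of two continuous leaks replaces a binary spike AND, so the gate carries graded magnitude information instead of only on/off.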
config.json ADDED
@@ -0,0 +1,21 @@
+{
+  "model_type": "neuronspark",
+  "architectures": [
+    "NeuronSparkForCausalLM"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration_neuronspark.NeuronSparkConfig",
+    "AutoModelForCausalLM": "modeling_neuronspark.NeuronSparkForCausalLM"
+  },
+  "vocab_size": 6144,
+  "D": 896,
+  "N": 8,
+  "K": 16,
+  "num_layers": 20,
+  "D_ff": 2688,
+  "v_th_min": 0.1,
+  "torch_dtype": "float32",
+  "transformers_version": "4.52.0",
+  "_training_step": 85000,
+  "_tokens_seen": 203910528
+}
configuration.json ADDED
@@ -0,0 +1 @@
+{"framework": "pytorch", "task": "other"}
configuration_neuronspark.py ADDED
@@ -0,0 +1,38 @@
+"""NeuronSpark model configuration."""
+
+from transformers import PretrainedConfig
+
+
+class NeuronSparkConfig(PretrainedConfig):
+    """
+    Configuration for the SNN hidden state-space language model.
+
+    Args:
+        vocab_size: vocabulary size
+        D: hidden dimension
+        N: state expansion factor (hidden neurons per channel)
+        K: maximum SNN time steps per token (PonderNet decides the effective count)
+        num_layers: number of SNN decoder layers
+        D_ff: FFN intermediate dimension
+        v_th_min: lower bound of the dynamic threshold
+    """
+    model_type = "neuronspark"
+
+    def __init__(
+        self,
+        vocab_size=6144,
+        D=896,
+        N=8,
+        K=16,
+        num_layers=20,
+        D_ff=2688,
+        v_th_min=0.1,
+        **kwargs,
+    ):
+        self.D = D
+        self.N = N
+        self.K = K
+        self.num_layers = num_layers
+        self.D_ff = D_ff
+        self.v_th_min = v_th_min
+        super().__init__(vocab_size=vocab_size, **kwargs)
model.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SNNLanguageModel: SNN 隐状态空间语言模型(全膜电位 + 动态 K)
3
+
4
+ 架构(三段式):
5
+ model.encode(token_ids) → h_seq # 输入: embed → repeat K 次(可微分)
6
+ model.snn_forward(h_seq) → h_out, pc # SNN 核心: 20 层,全膜电位 + 动态 K 聚合
7
+ model.decode(h_out, seq) → logits # 输出: output_neuron(V_post) → K帧mean → proj → logits
8
+
9
+ 核心设计:
10
+ 1. 膜电位泄漏量:PLIFNode 输出 (1-β)·V_post(泄漏量),自然强调快响应神经元
11
+ 2. 动态 K:PonderNet 自适应停止,不同 token 不同有效步数
12
+ - 每层每子层学习 halt_proj(D→1),从 SNN 输出逐步计算停止概率
13
+ - 几何分布权重加权聚合,替代 uniform mean
14
+ - ponder_cost 正则化鼓励早停
15
+
16
+ 数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
17
+ """
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import Optional
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from spikingjelly.activation_based import functional, surrogate
27
+ from torch.utils.checkpoint import checkpoint
28
+
29
+ from atomic_ops import SNNDecoderLayer
30
+ from atomic_ops.plif_node import PLIFNode
31
+ from atomic_ops.rms_norm import RMSNorm
32
+ from atomic_ops.parallel_scan import plif_rowparam_forward
33
+ # fp16_encode/fp16_decode 已移除: 全膜电位架构不需要 spike 编解码
34
+ from atomic_ops.lateral_inhibition import LateralInhibition
35
+
36
+
37
+ @dataclass
38
+ class SNNModelOutput:
39
+ """模型输出容器,对齐教程 CausalLMOutputWithPast 接口。"""
40
+ last_loss: Optional[torch.Tensor] = None
41
+ logits: Optional[torch.Tensor] = None
42
+ ponder_cost: Optional[torch.Tensor] = None # 动态 K: 平均期望步数
43
+
44
+
45
+ class SNNLanguageModel(nn.Module):
46
+ """
47
+ 从零训练的 SNN 隐状态空间语言模型(parallel scan)。
48
+
49
+ Args:
50
+ vocab_size: 词表大小(默认 6144,自训练 BPE)
51
+ D: 可见维度
52
+ N: 状态扩展因子
53
+ K: 每 token 最大 SNN 时间步数(K_max)。PonderNet 动态决定有效步数 ∈ [1, K]。
54
+ K 越大 → 复杂 token 可用更多步数,但计算量和显存线性增长。
55
+ num_layers: SNN 解码层数
56
+ D_ff: FFN 中间层维度
57
+ v_th_min: 动态阈值下限
58
+ """
59
+
60
+     def __init__(
+         self,
+         vocab_size: int = 6144,
+         D: int = 1024,
+         N: int = 8,
+         K: int = 32,
+         num_layers: int = 20,
+         D_ff: int = 3072,
+         v_th_min: float = 0.1,
+     ):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.D = D
+         self.N = N
+         self.K = K
+         self.num_layers = num_layers
+         self.D_ff = D_ff
+ 
+         # ====== Embedding + norm (all trainable) ======
+         self.embed_tokens = nn.Embedding(vocab_size, D)
+         self.norm = LateralInhibition(D)
+ 
+         # ====== Decode projection ======
+         self.decode_proj = nn.Linear(D, D)
+ 
+         # ====== Output RMSNorm + output neuron ======
+         self.output_norm = RMSNorm(D)
+         self.output_neuron = PLIFNode(
+             dim=D,
+             init_tau=2.0,
+             v_threshold=0.3,
+             surrogate_function=surrogate.Sigmoid(alpha=4.0),
+         )
+ 
+         # ====== SNN decoder layers ======
+         self.layers = nn.ModuleList([
+             SNNDecoderLayer(
+                 D=D, N=N, D_ff=D_ff, v_th_min=v_th_min,
+                 ffn_v_threshold=0.15,
+                 K=K,
+                 num_layers=num_layers,
+                 layer_idx=i,
+             )
+             for i in range(num_layers)
+         ])
+ 
+         self._init_weights()
+ 
+     def _init_weights(self):
+         """Initialize all trainable weights (training from scratch)."""
+         nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
+         nn.init.xavier_uniform_(self.decode_proj.weight)
+         nn.init.zeros_(self.decode_proj.bias)
+ 
+     def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
+         """Input boundary: token_ids → continuous-valued sequence.
+ 
+         Embedding lookup; each token is repeated K times as SNN time-step input.
+         Gradients flow straight back through the embedding.
+ 
+         Returns: (seq_len*K, batch, D), continuous values
+         """
+         emb = self.embed_tokens(token_ids)  # (batch, seq_len, D)
+         batch, seq_len, D = emb.shape
+         # repeat each token K times: (batch, seq_len, D) → (batch, seq_len*K, D) → (TK, batch, D)
+         emb_k = emb.unsqueeze(2).expand(-1, -1, self.K, -1).reshape(batch, seq_len * self.K, D)
+         return emb_k.permute(1, 0, 2).contiguous()  # (TK, batch, D)
+ 
+     def snn_forward(self, spike_seq: torch.Tensor):
+         """SNN core: spike_seq → (h_out, ponder_cost).
+ 
+         Pure SNN layer computation with gradient checkpointing.
+         Each layer returns (h, ponder_cost); ponder_cost is kept as a checkpoint
+         output so its gradient graph is preserved.
+ 
+         Returns:
+             h: (seq_len*K, batch, D), continuous values
+             total_ponder_cost: scalar, mean expected step count over all layers
+         """
+         h = spike_seq
+         ponder_costs = []
+ 
+         def _layer_forward(layer_mod, x):
+             functional.reset_net(layer_mod)
+             return layer_mod.forward_parallel(x)  # returns (h, ponder_cost)
+ 
+         for layer_module in self.layers:
+             h, pc = checkpoint(
+                 _layer_forward, layer_module, h,
+                 use_reentrant=False,
+             )
+             ponder_costs.append(pc)
+ 
+         total_ponder_cost = sum(ponder_costs) / len(ponder_costs)
+         return h, total_ponder_cost
+ 
+     def _output_neuron_parallel(self, h: torch.Tensor) -> torch.Tensor:
+         """Parallel-scan forward of the output PLIF neuron: continuous h → membrane leak.
+ 
+         Args:
+             h: (TK, batch, D) continuous values (output of the last SNN layer)
+ 
+         Returns:
+             leak: (TK, batch, D) membrane leak (1-β)·V_post
+         """
+         TK, batch, D = h.shape
+ 
+         beta = self.output_neuron.beta  # (D,)
+         u = (1.0 - beta) * h  # PLIF: u = (1-β) · x
+ 
+         v_init = self.output_neuron.v
+         if isinstance(v_init, float):
+             v_init = torch.zeros(batch, D, device=h.device, dtype=h.dtype)
+ 
+         beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
+         v_th_row = self.output_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()
+ 
+         spike, V_post = plif_rowparam_forward(
+             beta_row, u, v_th_row, v_init,
+             surrogate_function=self.output_neuron.surrogate_function,
+         )
+ 
+         self.output_neuron.v = V_post[-1].detach()
+         return (1.0 - beta) * V_post  # membrane leak
+ 
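For intuition, the leaky-integration recurrence that `plif_rowparam_forward` evaluates in parallel can be written as a sequential toy scan (scalar version; the spiking/reset path and the surrogate gradient are omitted):

```python
# Toy sequential PLIF scan: V_t = beta * V_{t-1} + u_t, with u_t = (1 - beta) * x_t.
# The Triton kernel computes the same V_post trajectory for (TK, batch, D)
# tensors in parallel over the time dimension.
def plif_scan(beta, xs, v0=0.0):
    v = v0
    vs = []
    for x in xs:
        v = beta * v + (1.0 - beta) * x  # leaky integration toward the input
        vs.append(v)
    return vs

vs = plif_scan(0.5, [1.0, 1.0, 1.0])
# the membrane converges toward a constant input: [0.5, 0.75, 0.875]
```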
+     def decode(self, h_out: torch.Tensor, seq_len: int) -> torch.Tensor:
+         """Output boundary: continuous h → output neuron (V_post) → K-frame aggregation → logits.
+ 
+         Gradient flow: loss → logits → norm → decode_proj → K-frame mean
+             → V_post (output_neuron) → h_out → SNN layers
+ 
+         Returns: (batch, seq_len, vocab_size)
+         """
+         h_out = self.output_norm(h_out)  # RMSNorm: controls the scale
+         v_out = self._output_neuron_parallel(h_out)  # (TK, batch, D), V_post membrane potential
+         # K-frame aggregation: (TK, batch, D) → (seq_len, K, batch, D) → mean → (seq_len, batch, D)
+         decoded = v_out.view(seq_len, self.K, -1, self.D).mean(dim=1)
+         decoded = decoded.permute(1, 0, 2)  # (batch, seq_len, D)
+         h = self.decode_proj(decoded)  # (batch, seq_len, D)
+         h = self.norm(h)  # (batch, seq_len, D)
+         return F.linear(h, self.embed_tokens.weight)  # (batch, seq_len, vocab)
+ 
+     @torch.no_grad()
+     def generate(
+         self,
+         prompt_ids: torch.Tensor,
+         max_new_tokens: int,
+         temperature: float = 1.0,
+         top_k: int = 50,
+         eos_token_id: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Autoregressive generation (SNN neuron state is carried continuously across tokens).
+ 
+         1. Prefill: forward_parallel processes the prompt in parallel and builds all neuron V states
+         2. Autoregressive: generate token by token; each token runs K frames through
+            forward_parallel, reusing the Triton parallel-scan kernel, with neuron V state
+            carried continuously across tokens
+ 
+         Args:
+             prompt_ids: (batch, prompt_len) token IDs
+             max_new_tokens: maximum number of tokens to generate
+             temperature: sampling temperature (<=0 = greedy)
+             top_k: top-k sampling (None/0 = unrestricted)
+             eos_token_id: stop generating when this token is produced
+ 
+         Returns:
+             (batch, prompt_len + generated_len) full sequence
+         """
+         batch, prompt_len = prompt_ids.shape
+ 
+         # reset all neurons (initial condition V=0 for a new sequence)
+         for layer_module in self.layers:
+             functional.reset_net(layer_module)
+         functional.reset_net(self.output_neuron)
+ 
+         # ====== Prefill: process the whole prompt in parallel ======
+         h_seq = self.encode(prompt_ids)  # (prompt_len*K, batch, D), continuous values
+         h = h_seq
+         for layer_module in self.layers:
+             h, _ = layer_module.forward_parallel(h)  # ponder_cost ignored at inference
+         # at this point every neuron's .v state = state at the end of the prompt
+ 
+         logits = self.decode(h, prompt_len)
+ 
+         # sample the first new token
+         next_token = self._sample(logits[:, -1, :], temperature, top_k)
+         generated = [next_token]
+ 
+         # ====== Autoregressive: token by token, K frames through forward_parallel ======
+         for _ in range(max_new_tokens - 1):
+             if eos_token_id is not None and (next_token == eos_token_id).all():
+                 break
+ 
+             # encode a single token → K continuous-valued frames (reuses encode)
+             frames = self.encode(next_token)  # (K, batch, D)
+ 
+             # run the K frames through the SNN — no reset; neuron .v carries across tokens
+             h = frames
+             for layer_module in self.layers:
+                 h, _ = layer_module.forward_parallel(h)
+ 
+             logits = self.decode(h, 1)
+ 
+             next_token = self._sample(logits[:, -1, :], temperature, top_k)
+             generated.append(next_token)
+ 
+         return torch.cat([prompt_ids, torch.cat(generated, dim=1)], dim=1)
+ 
+     def _sample(self, logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
+         """Sample from logits (temperature + top-k).
+ 
+         Returns: (batch, 1)
+         """
+         if temperature <= 0:
+             return logits.argmax(dim=-1, keepdim=True)
+         logits = logits / temperature
+         if top_k is not None and top_k > 0:
+             top_k = min(top_k, logits.size(-1))
+             v, _ = torch.topk(logits, top_k)
+             logits[logits < v[:, [-1]]] = float('-inf')
+         probs = F.softmax(logits, dim=-1)
+         return torch.multinomial(probs, num_samples=1)
+ 
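The same temperature + top-k sampling scheme can be mirrored over a plain list of logits, without torch (a toy sketch, not the model's actual sampling path):

```python
import math
import random

# Toy temperature + top-k sampling over a list of logits.
# temperature <= 0 falls back to greedy argmax, as in _sample above.
def sample(logits, temperature=1.0, top_k=None, rng=random):
    if temperature <= 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [l / temperature for l in logits]
    if top_k:
        # mask everything below the k-th largest scaled logit
        kth = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [l if l >= kth else float('-inf') for l in scaled]
    m = max(scaled)
    probs = [math.exp(l - m) for l in scaled]  # unnormalized softmax
    r = rng.random() * sum(probs)
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    return len(probs) - 1
```

With `top_k=1` the distribution collapses onto the argmax, which makes the behavior easy to check deterministically.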
+     def forward(
+         self,
+         token_ids: torch.Tensor,
+         target_ids: Optional[torch.Tensor] = None,
+     ) -> SNNModelOutput:
+         """
+         Forward pass (full membrane potential + dynamic K).
+ 
+         encode → h_seq           # input (embedding repeated K times, differentiable)
+         snn_forward → h_out, pc  # SNN core (full membrane potential + dynamic-K aggregation)
+         decode → logits          # output (V_post → K-frame mean → proj → logits)
+ 
+         Gradient flow:
+             embed_tokens → repeat K → SNN layers (V_post + dynamic K)
+             → output_neuron (V_post) → K-frame mean → decode_proj → logits (tied head)
+         ponder_cost: dynamic-K regularizer that encourages fewer steps for easy tokens
+         """
+         batch, seq_len = token_ids.shape
+ 
+         # reset all neuron states
+         for layer_module in self.layers:
+             functional.reset_net(layer_module)
+         functional.reset_net(self.output_neuron)
+ 
+         # three stages
+         spike_seq = self.encode(token_ids)  # input boundary
+         h_out, ponder_cost = self.snn_forward(spike_seq)  # SNN core + ponder cost
+         logits = self.decode(h_out, seq_len)  # output boundary
+ 
+         if target_ids is not None:
+             logits_flat = logits.reshape(-1, self.vocab_size)
+             targets_flat = target_ids.reshape(-1)
+             self.last_loss = F.cross_entropy(
+                 logits_flat, targets_flat,
+                 ignore_index=0, reduction='none',
+             )
+             return SNNModelOutput(
+                 last_loss=self.last_loss,
+                 ponder_cost=ponder_cost,
+             )
+ 
+         return SNNModelOutput(logits=logits, ponder_cost=ponder_cost)
+ 
+     def compensate_modulation_gradients(self, max_comp: float = 100.0):
+         """
+         Natural-gradient compensation (two phases).
+ 
+         Phase 1: sigmoid/softplus saturation compensation
+             β = sigmoid(b_beta); in the high-β regime the sigmoid attenuates gradients
+             ~100x (at β = 0.99, sigmoid' = 0.01).
+             Compensation: grad /= activation'(b), equivalent to gradient descent in β/α space.
+ 
+         Phase 2: cross-layer gradient equalization
+             Backprop through the residual chain amplifies each layer by ~1.17×, compounding
+             to a ~20× L0/L19 ratio over 20 layers. Gradients of the deep layers' selectivity
+             parameters (b_beta/b_alpha/b_th) are suppressed and cannot learn effectively.
+             Fix: normalize each layer's modulation-parameter gradient norm to the geometric
+             mean over all layers.
+ 
+         When to call: after scaler.unscale_(optimizer), before clip_grad_norm_.
+ 
+         Args:
+             max_comp: upper bound on the compensation factor (guards against instability from extreme values)
+         """
+         # ====== Phase 1: sigmoid/softplus saturation compensation ======
+         for layer_module in self.layers:
+             block = layer_module.snn_block
+ 
+             # b_beta: sigmoid saturation compensation
+             # sigmoid'(z) = sigmoid(z) · (1 - sigmoid(z)) = β · (1-β)
+             if block.b_beta.grad is not None:
+                 with torch.no_grad():
+                     beta = torch.sigmoid(block.b_beta.data)
+                     sigmoid_deriv = (beta * (1.0 - beta)).clamp(min=1.0 / max_comp)
+                     block.b_beta.grad.div_(sigmoid_deriv)
+ 
+             # b_alpha: softplus compensation (milder; softplus'(z) = sigmoid(z))
+             if block.b_alpha.grad is not None:
+                 with torch.no_grad():
+                     softplus_deriv = torch.sigmoid(block.b_alpha.data).clamp(min=0.1)
+                     block.b_alpha.grad.div_(softplus_deriv)
+ 
+             # b_th: the derivative of |·| is ±1 — no attenuation, no compensation needed
+ 
+         # ====== Phase 2: cross-layer gradient equalization ======
+         # Backward path of the residual chain h = h + sublayer(h): ∂h_{l+1}/∂h_l = I + ∂sublayer/∂h_l
+         # Each layer amplifies ~1.17×; over 20 layers ~20× → L0 gradients dwarf L19's.
+         # Normalize each layer's modulation-parameter gradient norm to the geometric mean
+         # to cancel the residual amplification.
+         with torch.no_grad():
+             for param_name in ['b_beta', 'b_alpha', 'b_th']:
+                 norms = []
+                 params_list = []
+                 for layer_module in self.layers:
+                     p = getattr(layer_module.snn_block, param_name)
+                     if p.grad is not None:
+                         n = p.grad.norm().item()
+                         if n > 1e-12:
+                             norms.append(n)
+                             params_list.append(p)
+ 
+                 if len(norms) >= 2:
+                     # geometric mean: exp(mean(log(norms))) — balances on a log scale, robust to extremes
+                     log_mean = sum(math.log(n) for n in norms) / len(norms)
+                     geo_mean = math.exp(log_mean)
+                     for p, n in zip(params_list, norms):
+                         scale = geo_mean / n
+                         scale = max(min(scale, max_comp), 1.0 / max_comp)
+                         p.grad.mul_(scale)
+ 
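The Phase 1 effect can be illustrated numerically with toy scalars (not the model's actual parameter values):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# At beta = sigmoid(b) close to 1, sigmoid'(b) = beta * (1 - beta) is tiny,
# so a raw gradient step in b-space barely moves beta. Dividing the gradient
# by sigmoid'(b), clamped at max_comp, restores an effective step in beta-space.
max_comp = 100.0
b = 4.6                      # sigmoid(4.6) ≈ 0.990 — the saturated regime
beta = sigmoid(b)
deriv = beta * (1.0 - beta)  # ≈ 0.0099: ~100x attenuation vs. beta = 0.5 (deriv = 0.25)
compensated = 1.0 / max(deriv, 1.0 / max_comp)  # factor applied to a unit gradient
```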
+     def get_param_groups(self) -> dict[str, list[nn.Parameter]]:
+         """
+         Trainable parameters grouped by function.
+         """
+         groups = {
+             'embedding': [self.embed_tokens.weight],
+             'norm': [self.norm.gain],
+             'decode': list(self.decode_proj.parameters()),
+             # output neuron
+             'output_neuron': [self.output_neuron.w, self.output_neuron.v_th],
+             # RMSNorm (Pre-LN branch normalization)
+             'rms_norms': [self.output_norm.weight],
+             # residual-stream components
+             'residual_projs': [],
+             'input_neurons': [],
+             # dynamic K: halting projections
+             'halt_projs': [],
+             # SNNBlock parameters
+             'W_in': [],
+             'W_beta': [],
+             'W_alpha': [],
+             'W_th': [],
+             'W_gate': [],
+             'W_skip': [],
+             'W_out': [],
+             'b_beta': [],
+             'b_alpha': [],
+             'b_th': [],
+             'block_output_neuron': [],
+             # SNNFFN parameters
+             'ffn_gate_proj': [],
+             'ffn_up_proj': [],
+             'ffn_down_proj': [],
+             'ffn_skip_proj': [],
+             'ffn_neurons': [],
+         }
+ 
+         for layer_module in self.layers:
+             block = layer_module.snn_block
+             ffn = layer_module.snn_ffn
+ 
+             # residual-stream components
+             groups['residual_projs'].extend([
+                 layer_module.block_out_proj.weight,
+                 layer_module.ffn_out_proj.weight,
+             ])
+             groups['input_neurons'].extend([
+                 layer_module.input_neuron1.w,
+                 layer_module.input_neuron1.v_th,
+                 layer_module.input_neuron2.w,
+                 layer_module.input_neuron2.v_th,
+             ])
+             groups['rms_norms'].extend([
+                 layer_module.block_norm.weight,
+                 layer_module.ffn_norm.weight,
+             ])
+ 
+             # dynamic K: halting projection parameters
+             groups['halt_projs'].extend(list(layer_module.block_halt.parameters()))
+             groups['halt_projs'].extend(list(layer_module.ffn_halt.parameters()))
+ 
+             # SNNBlock parameters
+             groups['W_in'].append(block.W_in.weight)
+             groups['W_beta'].extend([block.W_beta_x.weight])
+             groups['W_alpha'].extend([block.W_alpha_x.weight])
+             groups['W_th'].extend([block.W_th_x.weight])
+             groups['W_gate'].append(block.W_gate.weight)
+             groups['W_skip'].append(block.W_skip.weight)
+             groups['W_out'].append(block.W_out.weight)
+             groups['b_beta'].append(block.b_beta)
+             groups['b_alpha'].append(block.b_alpha)
+             groups['b_th'].append(block.b_th)
+ 
+             # SNNFFN parameters
+             groups['ffn_gate_proj'].append(ffn.gate_proj.weight)
+             groups['ffn_up_proj'].append(ffn.up_proj.weight)
+             groups['ffn_down_proj'].append(ffn.down_proj.weight)
+             groups['ffn_skip_proj'].append(ffn.skip_proj.weight)
+             groups['ffn_neurons'].extend([
+                 ffn.gate_neuron.w, ffn.gate_neuron.v_th,
+                 ffn.up_neuron.w, ffn.up_neuron.v_th,
+             ])
+ 
+         return groups
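Groupings like this are typically consumed by an optimizer that assigns per-group hyperparameters. A minimal, framework-free sketch of that consumption step (the override values are hypothetical, not from this repo; in practice the resulting list would be passed to e.g. torch.optim.AdamW):

```python
# Turn a {name: [params]} mapping into optimizer param-group dicts,
# overriding the learning rate for selected groups and dropping empty ones.
def build_param_groups(groups, base_lr=3e-4, overrides=None):
    overrides = overrides or {}
    return [
        {'params': params, 'lr': overrides.get(name, base_lr), 'name': name}
        for name, params in groups.items()
        if params  # skip groups that collected no parameters
    ]
```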
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:85a3b5312303134fced86a9d352175b91682176b6dea6af176f64bdaf6fc4b57
+ size 3496634368
modeling_neuronspark.py ADDED
@@ -0,0 +1,107 @@
+ """
+ NeuronSpark: SNN hidden-state-space language model — HuggingFace interface
+ 
+ Usage:
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+ 
+     model = AutoModelForCausalLM.from_pretrained(
+         "checkpoints_sft/", trust_remote_code=True,
+     )
+     tokenizer = AutoTokenizer.from_pretrained("checkpoints_sft/")
+ """
+ 
+ from typing import Optional
+ 
+ import torch
+ from transformers import PreTrainedModel, GenerationMixin
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ 
+ from configuration_neuronspark import NeuronSparkConfig
+ from model import SNNLanguageModel
+ 
+ 
+ class NeuronSparkForCausalLM(PreTrainedModel, GenerationMixin):
+     """
+     SNN language model — CausalLM interface.
+ 
+     Wraps SNNLanguageModel and exposes the standard HuggingFace interface:
+     - forward(input_ids, labels) → CausalLMOutputWithPast
+     - generate() support (via GenerationMixin)
+     """
+     config_class = NeuronSparkConfig
+     supports_gradient_checkpointing = True
+ 
+     def __init__(self, config: NeuronSparkConfig):
+         super().__init__(config)
+         self.model = SNNLanguageModel(
+             vocab_size=config.vocab_size,
+             D=config.D,
+             N=config.N,
+             K=config.K,
+             num_layers=config.num_layers,
+             D_ff=config.D_ff,
+             v_th_min=config.v_th_min,
+         )
+ 
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+ 
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+ 
+     def get_output_embeddings(self):
+         # tied head: the output reuses embed_tokens.weight
+         return self.model.embed_tokens
+ 
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         labels: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> CausalLMOutputWithPast:
+         """
+         Forward pass.
+ 
+         Args:
+             input_ids: (batch, seq_len) token IDs
+             labels: (batch, seq_len) target token IDs (optional, for computing the loss)
+             attention_mask: accepted for compatibility (the SNN has no attention; ignored)
+         """
+         if labels is not None:
+             out = self.model(input_ids, target_ids=labels)
+             # compute the masked loss
+             loss_mask = (labels != 0).float().view(-1)
+             loss = (out.last_loss * loss_mask).sum() / loss_mask.sum()
+             # add the ponder cost
+             if out.ponder_cost is not None:
+                 loss = loss + 0.01 * out.ponder_cost
+             return CausalLMOutputWithPast(loss=loss)
+         else:
+             out = self.model(input_ids)
+             return CausalLMOutputWithPast(logits=out.logits)
+ 
+     def prepare_inputs_for_generation(self, input_ids, **kwargs):
+         """Input preparation required by generate()."""
+         return {"input_ids": input_ids}
+ 
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.Tensor,
+         max_new_tokens: int = 256,
+         temperature: float = 1.0,
+         top_k: int = 50,
+         eos_token_id: Optional[int] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         """
+         Autoregressive generation (delegates directly to the SNN's generate method).
+         """
+         return self.model.generate(
+             prompt_ids=input_ids,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_k=top_k,
+             eos_token_id=eos_token_id,
+         )
special_tokens_map.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "bos_token": "<|im_start|>",
+   "eos_token": "<|im_end|>",
+   "unk_token": "<unk>",
+   "pad_token": "<|im_end|>",
+   "additional_special_tokens": [
+     "<s>",
+     "</s>"
+   ]
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "add_bos_token": false,
+   "add_eos_token": false,
+   "add_prefix_space": false,
+   "bos_token": "<|im_start|>",
+   "eos_token": "<|im_end|>",
+   "pad_token": "<|im_end|>",
+   "unk_token": "<unk>",
+   "model_max_length": 1000000000000000019884624838656,
+   "clean_up_tokenization_spaces": false,
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+ }
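As a sanity check, the chat_template's output for a conversation can be mirrored with plain string assembly (a sketch only; the canonical path is `tokenizer.apply_chat_template`):

```python
# Mirror of the Jinja chat_template: ChatML-style <|im_start|>role ... <|im_end|>
# turns, with an optional trailing assistant header as the generation prompt.
def render(messages, add_generation_prompt=True):
    out = ""
    for m in messages:
        if m["role"] in ("system", "user", "assistant"):  # other roles are skipped
            out += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    if add_generation_prompt:
        out += "<|im_start|>assistant\n"
    return out

prompt = render([{"role": "user", "content": "hello"}])
# "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n"
```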