Brain2nd committed
Commit 440e322 (verified) · 1 Parent(s): 87750f6

Initial release: NeuronSpark-0.9B-Chat instruction-tuned SNN language model
LICENSE ADDED
                              Apache License
                        Version 2.0, January 2004
                     http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

   "License" shall mean the terms and conditions for use, reproduction,
   and distribution as defined by Sections 1 through 9 of this document.

   "Licensor" shall mean the copyright owner or entity authorized by
   the copyright owner that is granting the License.

   "Legal Entity" shall mean the union of the acting entity and all
   other entities that control, are controlled by, or are under common
   control with that entity. For the purposes of this definition,
   "control" means (i) the power, direct or indirect, to cause the
   direction or management of such entity, whether by contract or
   otherwise, or (ii) ownership of fifty percent (50%) or more of the
   outstanding shares, or (iii) beneficial ownership of such entity.

   "You" (or "Your") shall mean an individual or Legal Entity
   exercising permissions granted by this License.

   "Source" form shall mean the preferred form for making modifications,
   including but not limited to software source code, documentation
   source, and configuration files.

   "Object" form shall mean any form resulting from mechanical
   transformation or translation of a Source form, including but
   not limited to compiled object code, generated documentation,
   and conversions to other media types.

   "Work" shall mean the work of authorship, whether in Source or
   Object form, made available under the License, as indicated by a
   copyright notice that is included in or attached to the work
   (an example is provided in the Appendix below).

   "Derivative Works" shall mean any work, whether in Source or Object
   form, that is based on (or derived from) the Work and for which the
   editorial revisions, annotations, elaborations, or other modifications
   represent, as a whole, an original work of authorship. For the purposes
   of this License, Derivative Works shall not include works that remain
   separable from, or merely link (or bind by name) to the interfaces of,
   the Work and Derivative Works thereof.

   "Contribution" shall mean any work of authorship, including
   the original version of the Work and any modifications or additions
   to that Work or Derivative Works thereof, that is intentionally
   submitted to the Licensor for inclusion in the Work by the copyright owner
   or by an individual or Legal Entity authorized to submit on behalf of
   the copyright owner. For the purposes of this definition, "submitted"
   means any form of electronic, verbal, or written communication sent
   to the Licensor or its representatives, including but not limited to
   communication on electronic mailing lists, source code control systems,
   and issue tracking systems that are managed by, or on behalf of, the
   Licensor for the purpose of discussing and improving the Work, but
   excluding communication that is conspicuously marked or otherwise
   designated in writing by the copyright owner as "Not a Contribution."

   "Contributor" shall mean Licensor and any individual or Legal Entity
   on behalf of whom a Contribution has been received by the Licensor and
   subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   copyright license to reproduce, prepare Derivative Works of,
   publicly display, publicly perform, sublicense, and distribute the
   Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   (except as stated in this section) patent license to make, have made,
   use, offer to sell, sell, import, and otherwise transfer the Work,
   where such license applies only to those patent claims licensable
   by such Contributor that are necessarily infringed by their
   Contribution(s) alone or by combination of their Contribution(s)
   with the Work to which such Contribution(s) was submitted. If You
   institute patent litigation against any entity (including a
   cross-claim or counterclaim in a lawsuit) alleging that the Work
   or a Contribution incorporated within the Work constitutes direct
   or contributory patent infringement, then any patent licenses
   granted to You under this License for that Work shall terminate
   as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
   Work or Derivative Works thereof in any medium, with or without
   modifications, and in Source or Object form, provided that You
   meet the following conditions:

   (a) You must give any other recipients of the Work or
       Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices
       stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works
       that You distribute, all copyright, patent, trademark, and
       attribution notices from the Source form of the Work,
       excluding those notices that do not pertain to any part of
       the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its
       distribution, then any Derivative Works that You distribute must
       include a readable copy of the attribution notices contained
       within such NOTICE file, excluding any notices that do not
       pertain to any part of the Derivative Works, in at least one
       of the following places: within a NOTICE text file distributed
       as part of the Derivative Works; within the Source form or
       documentation, if provided along with the Derivative Works; or,
       within a display generated by the Derivative Works, if and
       wherever such third-party notices normally appear. The contents
       of the NOTICE file are for informational purposes only and
       do not modify the License. You may add Your own attribution
       notices within Derivative Works that You distribute, alongside
       or as an addendum to the NOTICE text from the Work, provided
       that such additional attribution notices cannot be construed
       as modifying the License.

   You may add Your own copyright statement to Your modifications and
   may provide additional or different license terms and conditions
   for use, reproduction, or distribution of Your modifications, or
   for any such Derivative Works as a whole, provided Your use,
   reproduction, and distribution of the Work otherwise complies with
   the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
   any Contribution intentionally submitted for inclusion in the Work
   by You to the Licensor shall be under the terms and conditions of
   this License, without any additional terms or conditions.
   Notwithstanding the above, nothing herein shall supersede or modify
   the terms of any separate license agreement you may have executed
   with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
   names, trademarks, service marks, or product names of the Licensor,
   except as required for reasonable and customary use in describing the
   origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
   agreed to in writing, Licensor provides the Work (and each
   Contributor provides its Contributions) on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
   implied, including, without limitation, any warranties or conditions
   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
   PARTICULAR PURPOSE. You are solely responsible for determining the
   appropriateness of using or redistributing the Work and assume any
   risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
   whether in tort (including negligence), contract, or otherwise,
   unless required by applicable law (such as deliberate and grossly
   negligent acts) or agreed to in writing, shall any Contributor be
   liable to You for damages, including any direct, indirect, special,
   incidental, or consequential damages of any character arising as a
   result of this License or out of the use or inability to use the
   Work (including but not limited to damages for loss of goodwill,
   work stoppage, computer failure or malfunction, or any and all
   other commercial damages or losses), even if such Contributor
   has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
   the Work or Derivative Works thereof, You may choose to offer,
   and charge a fee for, acceptance of support, warranty, indemnity,
   or other liability obligations and/or rights consistent with this
   License. However, in accepting such obligations, You may act only
   on Your own behalf and on Your sole responsibility, not on behalf
   of any other Contributor, and only if You agree to indemnify,
   defend, and hold each Contributor harmless for any liability
   incurred by, or claims asserted against, such Contributor by reason
   of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

Copyright 2025 Zhengzheng Tang

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md ADDED
---
license: apache-2.0
language:
- zh
base_model: Brain2nd/NeuronSpark-0.9B
library_name: transformers
tags:
- snn
- spiking-neural-network
- text-generation
- neuromorphic
- chat
pipeline_tag: text-generation
---

# NeuronSpark-0.9B-Chat

## Introduction

**NeuronSpark-0.9B-Chat** is the **instruction-tuned chat version** of NeuronSpark-0.9B — a 0.87-billion-parameter language model built entirely on Spiking Neural Networks (SNNs). It has been fine-tuned on a small subset of the BelleGroup 3.5M Chinese instruction dataset to enable basic dialogue capabilities.

> **Note on training data**: Due to limited compute resources, both pretraining and SFT used only **small subsets** of their respective datasets (pretraining: ~1.4B of ~10B tokens; SFT: ~6.5K steps over a ~3.5M-sample dataset). Despite this minimal data budget, the model produces coherent Chinese dialogue — validating that pure SNN architectures can learn language from scratch. We plan to scale training with more data and compute in future work.

For the pretrained base model, see [NeuronSpark-0.9B](https://modelscope.cn/models/Brain2nd/NeuronSpark-0.9B).

## Model Details

| Attribute | Value |
|-----------|-------|
| Parameters | 874M |
| Architecture | SNN Hidden State Space Model |
| Hidden Dimension (D) | 896 |
| Layers | 20 |
| SNN Timesteps (K) | 16 (PonderNet adaptive) |
| State Expansion (N) | 8 |
| FFN Dimension | 2688 |
| Vocabulary | 6144 (custom BPE) |
| Context Length | 512 tokens |
| Base Model | NeuronSpark-0.9B (pretrained 85K steps) |
| SFT Data | BelleGroup train_3.5M_CN |
| SFT Steps | 6,500 |
| Chat Template | ChatML |
| License | Apache 2.0 |

## Quickstart

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "Brain2nd/NeuronSpark-0.9B-Chat",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("Brain2nd/NeuronSpark-0.9B-Chat")

# Chat
messages = [
    {"role": "system", "content": "你是一个AI助手"},
    {"role": "user", "content": "中国的首都是哪里?"},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = tokenizer(text, return_tensors="pt")["input_ids"]

output_ids = model.generate(
    input_ids,
    max_new_tokens=256,
    temperature=0.1,
    top_k=10,
    eos_token_id=tokenizer.eos_token_id,
)

# Extract the assistant response
full_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
response = full_text.split("assistant\n")[-1].replace("<|im_end|>", "").strip()
print(response)
```

**Example Output:**
```
Q: 中国的首都是哪里?
A: 中国的首都在北京。

Q: 你好呀
A: 请问您需要什么样的帮助?
```

## Architecture Highlights

- **Pure SNN**: No attention, no standard MLP — all computation is carried out by PLIF (Parametric Leaky Integrate-and-Fire) neurons
- **Membrane Potential Leakage Activation**: PLIFNode outputs `(1-β)·V_post` (the leak current), naturally emphasizing fast-responding neurons
- **Selective State Space**: Hidden neurons with input-dependent dynamic β(t), α(t), V_th(t)
- **PonderNet Adaptive K**: Each token dynamically decides how many SNN timesteps to use
- **Triton Fused Kernels**: Custom PLIF forward/backward kernels for an efficient parallel scan
- **ChatML Template**: Compatible with standard chat formatting
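The PLIF dynamics behind the first two bullets can be sketched as a scalar toy model. This is a hypothetical restatement of the update rule stated in the repository (leaky integrate, threshold, soft reset, leak-current output), not the actual Triton implementation:

```python
def plif_step(v_prev: float, u: float, beta: float, v_th: float):
    """One PLIF timestep on a single scalar neuron.

    V_pre  = beta * V_prev + u     # leaky integration of input current u
    s      = Θ(V_pre - v_th)       # spike if membrane crosses threshold
    V_post = V_pre - v_th * s      # soft reset: subtract the threshold
    leak   = (1 - beta) * V_post   # membrane-potential-leakage activation
    """
    v_pre = beta * v_prev + u
    s = 1.0 if v_pre >= v_th else 0.0
    v_post = v_pre - v_th * s
    return v_post, s, (1.0 - beta) * v_post
```

Note how the leak output `(1 - beta) * v_post` is larger for neurons with small `beta` (fast-leaking, fast-responding), which is the intuition given above.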
## Requirements

```bash
pip install torch transformers spikingjelly safetensors
```

## Limitations

- **Context length**: 512 tokens (limited by the training configuration)
- **Knowledge**: Trained on a Chinese-only corpus; limited factual accuracy
- **Repetition**: May generate repetitive text for complex queries
- **Scale**: 0.9B parameters — significantly smaller than state-of-the-art chat models

This is a **research model** demonstrating that SNN architectures can achieve basic language understanding and dialogue, even with very limited training data. It is not intended for production use. We plan to continue scaling with more data and compute.

## Citation

```bibtex
@misc{neuronspark2025,
  title={NeuronSpark: A Spiking Neural Network Language Model with Selective State Space Dynamics},
  author={Zhengzheng Tang},
  year={2025},
  url={https://github.com/Brain2nd/NeuronSpark}
}
```

## Contact

- **Author**: Zhengzheng Tang
- **Email**: zztangbu@bu.edu
- **GitHub**: [Brain2nd/NeuronSpark](https://github.com/Brain2nd/NeuronSpark)
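For reference, the ChatML layout that `apply_chat_template` is expected to produce looks roughly like the following. This is a hand-written sketch of the generic ChatML convention, not the repository's tokenizer configuration, which remains authoritative:

```python
def chatml(messages, add_generation_prompt=True):
    # Generic ChatML: each turn is  <|im_start|>role\ncontent<|im_end|>\n
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn to cue the model to answer
        text += "<|im_start|>assistant\n"
    return text
```

This layout also explains the response-extraction line in the Quickstart: the assistant's reply is whatever follows the final `assistant\n`, up to `<|im_end|>`.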
atomic_ops/__init__.py ADDED
from .selective_plif import SelectivePLIFNode
from .plif_node import PLIFNode
from .lateral_inhibition import LateralInhibition
from .snn_block import SNNBlock
from .snn_ffn import SNNFFN
from .snn_decoder_layer import SNNDecoderLayer
from .parallel_scan import hillis_steele_scan, linear_recurrence, plif_parallel_forward
from .fp16_codec import fp16_encode, fp16_decode
from .rms_norm import RMSNorm
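Two of the exports above, `hillis_steele_scan` and `linear_recurrence`, solve the recurrence V[k] = a[k]·V[k-1] + b[k] defined in `parallel_scan.py`. The core idea, that affine maps compose associatively so the recurrence is a prefix scan, can be sketched in plain Python (a toy scalar illustration; the actual functions operate on tensors):

```python
def recurrence_seq(a, b, v0):
    # Reference: V[k] = a[k] * V[k-1] + b[k],  V[-1] = v0
    v, out = v0, []
    for ak, bk in zip(a, b):
        v = ak * v + bk
        out.append(v)
    return out

def recurrence_scan(a, b, v0):
    # The affine map x -> A*x + B composes associatively:
    #   (A2, B2) ∘ (A1, B1) = (A2*A1, A2*B1 + B2)
    # so a Hillis-Steele inclusive scan over the (a[k], b[k]) pairs yields
    # every prefix composition in O(log K) passes (O(K log K) total work).
    pairs = list(zip(a, b))
    K, d = len(pairs), 1
    while d < K:
        nxt = list(pairs)
        for i in range(d, K):
            A2, B2 = pairs[i]      # later map
            A1, B1 = pairs[i - d]  # earlier prefix
            nxt[i] = (A2 * A1, A2 * B1 + B2)
        pairs = nxt
        d *= 2
    # Apply each prefix map to the initial state
    return [A * v0 + B for A, B in pairs]
```

Both functions produce the same trajectory; the scan version is the parallelizable form that the Triton kernels accelerate.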
atomic_ops/fp16_codec.py ADDED
"""
FP16 binary encode/decode: model-boundary operations (no trainable parameters).

IEEE 754 float16 bit layout (K = 16 timesteps):
    timestep:  0    1   2   3   4   5   6   7   8   9  10  11  12  13  14  15
    bit:      sign  E4  E3  E2  E1  E0  M9  M8  M7  M6  M5  M4  M3  M2  M1  M0
    meaning:  sign  <-- exponent (bias=15) -->  <-- mantissa (implicit 1.xxx) -->

Encode: embedding -> IEEE 754 float16 bit extraction -> 16 binary spike frames
        (detached; fixed preprocessing)
Decode: 16 binary spike frames -> IEEE 754 bit reconstruction -> continuous
        values (differentiable; gradients propagate via the surrogate gradient)
"""

import torch
from torch import Tensor


def fp16_encode(emb: Tensor, K: int = 16) -> Tensor:
    """FP16 binary encoding (model-boundary operation, fixed preprocessing).

    Converts continuous embeddings into IEEE 754 float16 bit patterns, which
    serve as the spike input to the SNN.

    Args:
        emb: (batch, seq_len, D) continuous embeddings
        K: number of timesteps (must be 16, one per float16 bit)

    Returns:
        spike_seq: (seq_len*K, batch, D) binary {0, 1}, detached
    """
    batch, seq_len, D = emb.shape

    # Cast to float16 to obtain the IEEE 754 bit pattern.
    # Clamp to avoid overflow to Inf (float16 max is 65504).
    emb_fp16 = emb.float().clamp(-65504.0, 65504.0).half()
    bits_int = emb_fp16.view(torch.int16)  # (batch, seq_len, D)

    # Extract the 16 bits (MSB first: sign, exponent, mantissa)
    shifts = torch.arange(15, -1, -1, device=emb.device)  # [15, 14, ..., 0]
    # bits_int: (batch, seq_len, D) -> unsqueeze -> (batch, seq_len, 1, D)
    # shifts: (K,) -> view -> (1, 1, K, 1)
    bits = ((bits_int.unsqueeze(2) >> shifts.view(1, 1, K, 1)) & 1)  # (batch, seq_len, K, D)

    # Cast to the compute dtype and detach (encoding takes no gradient)
    bits = bits.to(emb.dtype).detach()

    # reshape -> (seq_len*K, batch, D)
    return bits.reshape(batch, seq_len * K, D).permute(1, 0, 2).contiguous()


def fp16_decode(spikes: Tensor, seq_len: int, K: int = 16) -> Tensor:
    """Exact FP16 bit decoding: rebuild float16 values from 16 binary spikes.

    Exact inverse of fp16_encode. Fully differentiable: gradients flow through
    the IEEE 754 reconstruction formula to each spike output, and from there
    into the SNN via the surrogate gradient.

    IEEE 754 float16 reconstruction:
        Normal    (exp > 0): (-1)^sign * 2^(exp - 15) * (1 + mant_frac)
        Subnormal (exp = 0): (-1)^sign * 2^(-14) * mant_frac
        where mant_frac = Σ mant_bit_i * 2^{-(i+1)}, i = 0..9

    Args:
        spikes: (seq_len*K, batch, D) binary {0, 1} (output-neuron spikes)
        seq_len: token sequence length
        K: number of timesteps (= 16)

    Returns:
        decoded: (batch, seq_len, D) continuous values
    """
    batch, D = spikes.shape[1], spikes.shape[2]

    # (seq_len*K, batch, D) -> (batch, seq_len, K, D)
    s = spikes.permute(1, 0, 2).reshape(batch, seq_len, K, D)

    # ---- Sign: bit 0 ----
    sign = 1.0 - 2.0 * s[:, :, 0, :]  # +1 or -1

    # ---- Exponent: bits 1-5, weighted sum -> integer in 0..31 ----
    exp_weights = torch.tensor(
        [16.0, 8.0, 4.0, 2.0, 1.0],
        device=spikes.device, dtype=spikes.dtype,
    )
    exp_val = (s[:, :, 1:6, :] * exp_weights.view(1, 1, 5, 1)).sum(dim=2)

    # ---- Mantissa fraction: bits 6-15, weighted sum -> [0, 1) ----
    mant_weights = torch.tensor(
        [2.0 ** (-i) for i in range(1, 11)],
        device=spikes.device, dtype=spikes.dtype,
    )
    mant_frac = (s[:, :, 6:, :] * mant_weights.view(1, 1, 10, 1)).sum(dim=2)

    # ---- IEEE 754 reconstruction ----
    # Normal:    (-1)^s * 2^(exp-15) * (1 + mant_frac)
    # Subnormal: (-1)^s * 2^(-14) * mant_frac
    is_normal = (exp_val > 0)

    normal_val = sign * torch.exp2(exp_val - 15.0) * (1.0 + mant_frac)
    subnormal_val = sign * (2.0 ** -14) * mant_frac

    return torch.where(is_normal, normal_val, subnormal_val)
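The same bit layout and reconstruction formula can be checked on scalars with only the standard library, using `struct`'s binary16 format character `'e'`. This is a standalone illustration that mirrors, but does not call, the tensor code above:

```python
import struct

def f16_bits(x: float) -> list:
    # Pack as IEEE 754 binary16 and read its bits MSB-first:
    # [sign, E4..E0, M9..M0], the same 16-frame order fp16_encode emits.
    (n,) = struct.unpack('<H', struct.pack('<e', x))
    return [(n >> (15 - i)) & 1 for i in range(16)]

def f16_from_bits(bits) -> float:
    # The reconstruction formula used by fp16_decode, applied to one scalar.
    sign = -1.0 if bits[0] else 1.0
    exp_val = sum(b << (4 - i) for i, b in enumerate(bits[1:6]))
    mant_frac = sum(b * 2.0 ** -(i + 1) for i, b in enumerate(bits[6:16]))
    if exp_val > 0:                       # normal: 2^(exp-15) * (1 + frac)
        return sign * 2.0 ** (exp_val - 15) * (1.0 + mant_frac)
    return sign * 2.0 ** -14 * mant_frac  # subnormal: 2^-14 * frac
```

Round-tripping any value through `f16_bits` and `f16_from_bits` recovers exactly its float16-rounded representation, which is the invariant the encoder/decoder pair above relies on.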
atomic_ops/lateral_inhibition.py ADDED
"""
LateralInhibition: lateral-inhibition normalization (divisive normalization).

Neuroscience background:
    Carandini & Heeger (2012), "Normalization as a canonical neural computation".
    Lateral inhibition is one of the most basic computational primitives in the
    brain: the activity of excitatory neurons is fed back through a pool of
    inhibitory interneurons, implementing gain control.

SNN mechanism:
    1. Excitatory population activity: activity_i = h_i²
    2. Inhibitory interneuron pool:    pool = mean(activity) = mean(h²)
    3. Shunting inhibition:            h_norm = h / sqrt(pool + ε)
    4. Gain modulation:                output = gain · h_norm

Replaces RMSNorm: the math is equivalent, but in an SNN framework it has a
direct neuroscientific counterpart, since RMSNorm is a special case of
divisive normalization.

Triton fused kernels:
    - forward:  {mean(h²), rsqrt, element-wise mul} -> 1 kernel launch
    - backward: {recompute norm, grad_gain, grad_h} -> 1 kernel launch
    - one block per row (the D dimension); rows run in parallel
"""

import os

import torch
import torch.nn as nn
from spikingjelly.activation_based import base


# ============================================================
# Triton fused kernels
# ============================================================

_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
    os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS

_HAS_TRITON = False
try:
    import triton
    import triton.language as tl
    _HAS_TRITON = True
except ImportError:
    pass


if _HAS_TRITON:

    @triton.jit
    def _li_fwd_kernel(
        X_ptr, GAIN_ptr, OUT_ptr,
        stride_row,
        D: tl.constexpr,
        eps: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Forward: out = x * rsqrt(mean(x²) + eps) * gain

        Grid: (num_rows,). Each program processes one row of D elements.
        Computation in float32; storage in input dtype.
        """
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_D)
        mask = cols < D
        off = row * stride_row + cols

        # Load in float32
        x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
        gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # Inhibitory pool: population activity
        variance = tl.sum(x * x, axis=0) / D
        rrms = 1.0 / tl.sqrt(variance + eps)

        # Divisive inhibition + gain modulation
        out = x * rrms * gain

        tl.store(OUT_ptr + off, out, mask=mask)

    @triton.jit
    def _li_bwd_kernel(
        DOUT_ptr, X_ptr, GAIN_ptr,
        DX_ptr, DGAIN_ptr,
        stride_row,
        D: tl.constexpr,
        eps: tl.constexpr,
        BLOCK_D: tl.constexpr,
    ):
        """Backward: grad_x, grad_gain (per-row, reduced externally).

        Grid: (num_rows,).
        d_x = rrms * (d_out * gain - x_hat * mean(d_out * gain * x_hat))
        d_gain_row = d_out * x_hat  (sum across rows done outside the kernel)
        """
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_D)
        mask = cols < D
        off = row * stride_row + cols

        dout = tl.load(DOUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
        gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # Recompute the forward pass (avoids saving intermediate tensors)
        variance = tl.sum(x * x, axis=0) / D
        rrms = 1.0 / tl.sqrt(variance + eps)
        x_hat = x * rrms

        # grad_gain (per-row contribution)
        dgain = dout * x_hat
        tl.store(DGAIN_ptr + off, dgain, mask=mask)

        # grad_x: rrms * (dout*gain - x_hat * mean(dout*gain*x_hat))
        dout_gain = dout * gain
        dot = tl.sum(dout_gain * x_hat, axis=0) / D
        dx = (dout_gain - x_hat * dot) * rrms

        tl.store(DX_ptr + off, dx, mask=mask)


class _LateralInhibitionTriton(torch.autograd.Function):
    """Triton-accelerated lateral inhibition (divisive normalization)."""

    @staticmethod
    def forward(ctx, x, gain, eps):
        orig_shape = x.shape
        D = x.shape[-1]
        x_2d = x.reshape(-1, D).contiguous()
        N = x_2d.shape[0]

        out = torch.empty_like(x_2d)

        BLOCK_D = triton.next_power_of_2(D)
        _li_fwd_kernel[(N,)](
            x_2d, gain, out,
            x_2d.stride(0),
            D=D, eps=eps, BLOCK_D=BLOCK_D,
        )

        ctx.save_for_backward(x_2d, gain)
        ctx.eps = eps
        ctx.orig_shape = orig_shape
        ctx.N = N
        ctx.D = D

        return out.reshape(orig_shape)

    @staticmethod
    def backward(ctx, grad_output):
        x_2d, gain = ctx.saved_tensors
        D = ctx.D
        N = ctx.N

        grad_2d = grad_output.reshape(N, D).contiguous()

        dx = torch.empty_like(x_2d)
        dgain_rows = torch.empty_like(x_2d)

        BLOCK_D = triton.next_power_of_2(D)
        _li_bwd_kernel[(N,)](
            grad_2d, x_2d, gain,
            dx, dgain_rows,
            x_2d.stride(0),
            D=D, eps=ctx.eps, BLOCK_D=BLOCK_D,
        )

        # Reduce per-row dgain across all rows
        dgain = dgain_rows.sum(dim=0)

        return dx.reshape(ctx.orig_shape), dgain, None


# ============================================================
# Public module
# ============================================================

class LateralInhibition(base.MemoryModule):
    """
    Lateral-inhibition normalization layer (divisive normalization).

    Implements gain control via a pool of inhibitory interneurons.

    Math:
        pool = mean(h², dim=-1)      # inhibitory pool: population activity level
        h_norm = h / sqrt(pool + ε)  # shunting inhibition
        output = gain · h_norm       # gain modulation

    Equivalent to RMSNorm, but in an SNN framework it corresponds to divisive
    normalization (Carandini & Heeger, 2012), one of the most basic
    computational primitives in neuroscience.

    CUDA: Triton fused kernels (1 launch each for forward and backward)
    CPU: PyTorch fallback

    Args:
        dim: feature dimension (D)
        eps: numerical-stability constant
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(dim))
        self.eps = eps
        self.dim = dim

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        if _HAS_TRITON and h.is_cuda:
            return _LateralInhibitionTriton.apply(h, self.gain, self.eps)
        # PyTorch fallback
        variance = h.pow(2).mean(-1, keepdim=True)
        h_norm = h * torch.rsqrt(variance + self.eps)
        return self.gain * h_norm

    def extra_repr(self):
        return f'dim={self.dim}, eps={self.eps}'
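The PyTorch fallback path above reduces to a few lines of arithmetic. A scalar-list sketch (a toy restatement of the math, not the Triton kernel) makes the RMSNorm equivalence explicit:

```python
import math

def lateral_inhibition(h, gain, eps=1e-6):
    # pool   = mean(h^2)             inhibitory pool: population activity
    # h_norm = h / sqrt(pool + eps)  shunting (divisive) inhibition
    # out    = gain * h_norm         gain modulation (this is exactly RMSNorm)
    pool = sum(x * x for x in h) / len(h)
    rrms = 1.0 / math.sqrt(pool + eps)
    return [g * x * rrms for g, x in zip(gain, h)]
```

With unit gain and `h = [3, 4, 0, 0]`, the pool is 6.25 and the RMS is 2.5, so the output is approximately `[1.2, 1.6, 0, 0]`: large activations are scaled down in proportion to the whole population's activity.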
atomic_ops/parallel_scan.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parallel Scan 工具函数:SNN 线性递推的高效并行求解
3
+
4
+ 实现三层后端:
5
+ 1. Fused PLIF kernel(默认,CUDA + Sigmoid surrogate):
6
+ 单 kernel 完成 PLIF 前向(scan + spike + soft reset)和反向(surrogate gradient)
7
+ · per-element beta/v_th: _fused_plif_fwd_kernel / _fused_plif_bwd_kernel
8
+ · row-param beta/v_th: _fused_plif_fwd_rowparam_kernel / _fused_plif_bwd_rowparam_kernel
9
+ 2. Triton linear_recurrence(CUDA,非 Sigmoid 或无 surrogate):
10
+ 列级并行 scan,O(K) 工作量,1 次 kernel launch
11
+ 3. Hillis-Steele parallel scan(CPU 回退):O(K log K) 工作量
12
+
13
+ 线性递推:
14
+ V[k] = a[k] * V[k-1] + b[k], V[-1] = v_init
15
+
16
+ PLIF 神经元动力学:
17
+ V_pre[k] = beta[k] * V_post[k-1] + u[k]
18
+ s[k] = Θ(V_pre[k] - v_th[k])
19
+ V_post[k] = V_pre[k] - v_th[k] * s[k]
20
+
21
+ 数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
22
+ """
23
+
24
+ import os
25
+ import torch
26
+
27
+
28
+ # ============================================================
29
+ # Triton fused recurrence kernels
30
+ # ============================================================
31
+
32
+ # DGX Spark (GB10, sm_121a): Triton 3.5.1 自带 ptxas 不支持 sm_121a,
33
+ # 需要使用系统 CUDA 13.0 的 ptxas
34
+ _SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
35
+ if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
36
+ os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS
37
+
38
+ _HAS_TRITON = False
39
+ try:
40
+ import triton
41
+ import triton.language as tl
42
+ _HAS_TRITON = True
43
+ except ImportError:
44
+ pass
45
+
46
+ if _HAS_TRITON:
47
+
48
+ @triton.jit
49
+ def _fwd_recurrence_kernel(
50
+ A_ptr, B_ptr, INIT_ptr, OUT_ptr,
51
+ K, num_cols,
52
+ BLOCK: tl.constexpr,
53
+ ):
54
+ """Forward: V[k] = A[k]*V[k-1] + B[k], V[-1] = init.
55
+
56
+ Grid: (ceil(num_cols / BLOCK),)
57
+ Each program processes BLOCK columns across all K sequential steps.
58
+ Accumulation in float32; storage in input dtype.
59
+ """
60
+ pid = tl.program_id(0)
61
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
62
+ mask = cols < num_cols
63
+
64
+ v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
65
+
66
+ for k in range(K):
67
+ off = k * num_cols + cols
68
+ a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
69
+ b = tl.load(B_ptr + off, mask=mask, other=0.0).to(tl.float32)
70
+ v = a * v + b
71
+ tl.store(OUT_ptr + off, v, mask=mask)
72
+
73
+ @triton.jit
74
+ def _bwd_recurrence_kernel(
75
+ A_ptr, V_ptr, INIT_ptr, GRAD_OUT_ptr,
76
+ GRAD_A_ptr, GRAD_B_ptr, GRAD_INIT_ptr,
77
+ K, num_cols,
78
+ BLOCK: tl.constexpr,
79
+ ):
80
+ """Backward for V[k] = A[k]*V[k-1] + B[k].
81
+
82
+ Reverse accumulation (k from K-1 down to 0):
83
+ g = 0
84
+ for k = K-1, ..., 0:
85
+ g += grad_out[k]
86
+ grad_B[k] = g
87
+ grad_A[k] = g * V[k-1] (V[-1] = init)
88
+ g = g * A[k]
89
+ grad_init = g
90
+ """
91
+ pid = tl.program_id(0)
92
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
93
+ mask = cols < num_cols
94
+
95
+ g = tl.zeros([BLOCK], dtype=tl.float32)
96
+
97
+ for k_rev in range(K):
98
+ k = K - 1 - k_rev
99
+ off = k * num_cols + cols
100
+
101
+ dV = tl.load(GRAD_OUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
102
+ g = g + dV
103
+
104
+ tl.store(GRAD_B_ptr + off, g, mask=mask)
105
+
106
+ if k > 0:
107
+ v_prev = tl.load(
108
+ V_ptr + (k - 1) * num_cols + cols,
109
+ mask=mask, other=0.0,
110
+ ).to(tl.float32)
111
+ else:
112
+ v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
113
+ tl.store(GRAD_A_ptr + off, g * v_prev, mask=mask)
114
+
115
+ a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
116
+ g = g * a
117
+
118
+ tl.store(GRAD_INIT_ptr + cols, g, mask=mask)
119
+
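The reverse accumulation implemented by `_bwd_recurrence_kernel` can be sanity-checked against a finite-difference probe with a scalar pure-Python sketch (the helper names `fwd`/`bwd` below are illustrative, not part of the module):

```python
def fwd(a, b, v0):
    # V[k] = a[k] * V[k-1] + b[k], V[-1] = v0
    out, v = [], v0
    for ak, bk in zip(a, b):
        v = ak * v + bk
        out.append(v)
    return out

def bwd(a, V, v0, gout):
    # mirrors _bwd_recurrence_kernel: g accumulates dL/dV[k] going backwards
    K = len(a)
    ga, gb = [0.0] * K, [0.0] * K
    g = 0.0
    for k in range(K - 1, -1, -1):
        g += gout[k]
        gb[k] = g                          # dL/dB[k]
        v_prev = V[k - 1] if k > 0 else v0
        ga[k] = g * v_prev                 # dL/dA[k] = g * V[k-1]
        g *= a[k]                          # propagate to the previous step
    return ga, gb, g                       # g is now dL/dv_init

a, b, v0 = [0.9, 0.8, 0.7], [0.1, -0.2, 0.3], 0.5
gout = [1.0, 1.0, 1.0]                     # loss L = sum(V)
ga, gb, gv0 = bwd(a, fwd(a, b, v0), v0, gout)

eps = 1e-6                                 # finite-difference check for dL/da[0]
ap = [a[0] + eps] + a[1:]
num = (sum(fwd(ap, b, v0)) - sum(fwd(a, b, v0))) / eps
assert abs(num - ga[0]) < 1e-4
```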
120
+ class _TritonLinearRecurrence(torch.autograd.Function):
121
+ """Fused Triton linear recurrence: V[k] = A[k]*V[k-1] + B[k]."""
122
+
123
+ _BLOCK = 128
124
+
125
+ @staticmethod
126
+ def forward(ctx, beta, u, v_init):
127
+ beta_c = beta.contiguous()
128
+ u_c = u.contiguous()
129
+ v_init_c = v_init.contiguous()
130
+
131
+ K = beta_c.shape[0]
132
+ num_cols = beta_c[0].numel()
133
+ V = torch.empty_like(u_c)
134
+
135
+ BLOCK = _TritonLinearRecurrence._BLOCK
136
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
137
+
138
+ _fwd_recurrence_kernel[grid](
139
+ beta_c, u_c, v_init_c, V,
140
+ K, num_cols,
141
+ BLOCK=BLOCK,
142
+ )
143
+
144
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
145
+ ctx.save_for_backward(beta_c, V, v_init_c)
146
+ ctx.K = K
147
+ ctx.num_cols = num_cols
148
+
149
+ return V
150
+
151
+ @staticmethod
152
+ def backward(ctx, grad_V):
153
+ beta, V, v_init = ctx.saved_tensors
154
+ grad_V_c = grad_V.contiguous()
155
+
156
+ K = ctx.K
157
+ num_cols = ctx.num_cols
158
+
159
+ grad_beta = torch.empty_like(beta)
160
+ grad_u = torch.empty_like(beta)
161
+ grad_v_init = torch.empty_like(v_init)
162
+
163
+ BLOCK = _TritonLinearRecurrence._BLOCK
164
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
165
+
166
+ _bwd_recurrence_kernel[grid](
167
+ beta, V, v_init, grad_V_c,
168
+ grad_beta, grad_u, grad_v_init,
169
+ K, num_cols,
170
+ BLOCK=BLOCK,
171
+ )
172
+
173
+ return grad_beta, grad_u, grad_v_init
174
+
175
+ # ============================================================
176
+ # Fused PLIF forward/backward kernels
177
+ # ============================================================
178
+
179
+ @triton.jit
180
+ def _fused_plif_fwd_kernel(
181
+ BETA_ptr, U_ptr, VTH_ptr, INIT_ptr,
182
+ SPIKE_ptr, VPOST_ptr,
183
+ K, num_cols,
184
+ BLOCK: tl.constexpr,
185
+ ):
186
+ """Fused PLIF forward: single-pass sequential scan with inline spike + soft reset.
187
+
188
+ Exact computation — sequential scan IS the ground truth.
189
+ Replaces the 3-phase approach (linear scan + spike iteration + correction).
190
+
191
+ Per column (parallel across batch*D):
192
+ v = v_init
193
+ for k = 0..K-1:
194
+ v_pre = beta[k]*v + u[k]
195
+ spike[k] = Θ(v_pre - v_th[k])
196
+ v = v_pre - v_th[k]*spike[k]
197
+ """
198
+ pid = tl.program_id(0)
199
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
200
+ mask = cols < num_cols
201
+
202
+ v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
203
+
204
+ for k in range(K):
205
+ off = k * num_cols + cols
206
+ beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
207
+ u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
208
+ vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
209
+
210
+ v_pre = beta * v + u
211
+ spike = tl.where(v_pre >= vth, 1.0, 0.0)
212
+ v = v_pre - vth * spike # soft reset
213
+
214
+ tl.store(SPIKE_ptr + off, spike, mask=mask)
215
+ tl.store(VPOST_ptr + off, v, mask=mask)
216
+
217
+ @triton.jit
218
+ def _fused_plif_bwd_kernel(
219
+ BETA_ptr, VTH_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
220
+ GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
221
+ GRAD_BETA_ptr, GRAD_U_ptr, GRAD_VTH_ptr, GRAD_INIT_ptr,
222
+ K, num_cols, ALPHA,
223
+ BLOCK: tl.constexpr,
224
+ ):
225
+ """Fused PLIF backward: single reverse pass with Sigmoid surrogate gradient.
226
+
227
+ V_pre[k] = V_post[k] + v_th[k]*spike[k] (reconstructed)
228
+ surrogate_grad(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
229
+ where x = V_pre[k] - v_th[k] = V_post[k] - v_th[k]*(1 - spike[k])
230
+
231
+ Reverse accumulation:
232
+ acc = 0
233
+ for k = K-1 downto 0:
234
+ total_gV = grad_V_post[k] + acc
235
+ sg = surrogate_grad(V_pre[k] - v_th[k])
236
+ grad_v_pre = grad_spike[k]*sg + total_gV
237
+ grad_beta[k] = grad_v_pre * V_post[k-1]
238
+ grad_u[k] = grad_v_pre
239
+ grad_v_th[k] = -grad_spike[k]*sg - total_gV*spike[k]
240
+ acc = grad_v_pre * beta[k]
241
+ grad_v_init = acc
242
+ """
243
+ pid = tl.program_id(0)
244
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
245
+ mask = cols < num_cols
246
+
247
+ acc = tl.zeros([BLOCK], dtype=tl.float32)
248
+
249
+ for k_rev in range(K):
250
+ k = K - 1 - k_rev
251
+ off = k * num_cols + cols
252
+
253
+ beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
254
+ vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
255
+ v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
256
+ spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
257
+
258
+ g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
259
+ g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
260
+
261
+ # V_post[k-1]
262
+ if k > 0:
263
+ v_prev = tl.load(
264
+ VPOST_ptr + (k - 1) * num_cols + cols,
265
+ mask=mask, other=0.0,
266
+ ).to(tl.float32)
267
+ else:
268
+ v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
269
+
270
+ # Sigmoid surrogate gradient
271
+ x = v_post - vth * (1.0 - spike) # = V_pre - v_th
272
+ neg_ax = -ALPHA * x
273
+ neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax) # prevent exp overflow
274
+ sig = 1.0 / (1.0 + tl.exp(neg_ax))
275
+ sg = ALPHA * sig * (1.0 - sig)
276
+
277
+ total_gV = g_V + acc
278
+ grad_v_pre = g_s * sg + total_gV
279
+
280
+ tl.store(GRAD_BETA_ptr + off, grad_v_pre * v_prev, mask=mask)
281
+ tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
282
+ tl.store(GRAD_VTH_ptr + off, -g_s * sg - total_gV * spike, mask=mask)
283
+
284
+ acc = grad_v_pre * beta
285
+
286
+ tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
287
+
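The sigmoid surrogate used by both backward kernels can be verified numerically in plain Python (a sketch; `surrogate_grad` and `smooth_step` are illustrative names): sg(x) = α·σ(αx)·(1−σ(αx)) is exactly the derivative of the smooth step σ(αx) that stands in for the Heaviside in the backward pass.

```python
import math

def smooth_step(x, alpha=4.0):
    # sigmoid(alpha*x): the smooth relaxation of the Heaviside step
    return 1.0 / (1.0 + math.exp(-alpha * x))

def surrogate_grad(x, alpha=4.0):
    # sg(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
    s = smooth_step(x, alpha)
    return alpha * s * (1.0 - s)

eps = 1e-6
for x in (-0.5, 0.0, 0.3):
    # central finite difference of the smooth step matches the closed form
    num = (smooth_step(x + eps) - smooth_step(x - eps)) / (2 * eps)
    assert abs(num - surrogate_grad(x)) < 1e-6
```

The peak value at x = 0 is α/4 (here 1.0), which is why larger α gives a sharper but noisier gradient around the threshold.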
288
+ # ============================================================
289
+ # Fused PLIF kernels with row-parameter beta/v_th
290
+ # (constant across K steps — e.g., ParametricLIFNode scalars)
291
+ # ============================================================
292
+
293
+ @triton.jit
294
+ def _fused_plif_fwd_rowparam_kernel(
295
+ BETA_ROW_ptr, U_ptr, VTH_ROW_ptr, INIT_ptr,
296
+ SPIKE_ptr, VPOST_ptr,
297
+ K, num_cols,
298
+ BLOCK: tl.constexpr,
299
+ ):
300
+ """Fused PLIF forward with row-parameter beta and v_th.
301
+
302
+ beta and v_th are (*shape) — constant across K steps, loaded once into registers.
303
+ Reduces global memory reads from 3 per step (beta, u, v_th) to 1 (u only).
304
+ """
305
+ pid = tl.program_id(0)
306
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
307
+ mask = cols < num_cols
308
+
309
+ v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
310
+ beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
311
+ vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
312
+
313
+ for k in range(K):
314
+ off = k * num_cols + cols
315
+ u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
316
+
317
+ v_pre = beta * v + u
318
+ spike = tl.where(v_pre >= vth, 1.0, 0.0)
319
+ v = v_pre - vth * spike
320
+
321
+ tl.store(SPIKE_ptr + off, spike, mask=mask)
322
+ tl.store(VPOST_ptr + off, v, mask=mask)
323
+
324
+ @triton.jit
325
+ def _fused_plif_bwd_rowparam_kernel(
326
+ BETA_ROW_ptr, VTH_ROW_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
327
+ GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
328
+ GRAD_BETA_ROW_ptr, GRAD_U_ptr, GRAD_VTH_ROW_ptr, GRAD_INIT_ptr,
329
+ K, num_cols, ALPHA,
330
+ BLOCK: tl.constexpr,
331
+ ):
332
+ """Fused PLIF backward with row-parameter beta/v_th.
333
+
334
+ Gradients for beta and v_th are accumulated over K steps (reduction in registers).
335
+ Returns grad_beta_row (*shape) and grad_v_th_row (*shape) instead of per-step gradients.
336
+ """
337
+ pid = tl.program_id(0)
338
+ cols = pid * BLOCK + tl.arange(0, BLOCK)
339
+ mask = cols < num_cols
340
+
341
+ beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
342
+ vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
343
+
344
+ acc = tl.zeros([BLOCK], dtype=tl.float32)
345
+ acc_grad_beta = tl.zeros([BLOCK], dtype=tl.float32)
346
+ acc_grad_vth = tl.zeros([BLOCK], dtype=tl.float32)
347
+
348
+ for k_rev in range(K):
349
+ k = K - 1 - k_rev
350
+ off = k * num_cols + cols
351
+
352
+ v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
353
+ spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
354
+
355
+ g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
356
+ g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
357
+
358
+ if k > 0:
359
+ v_prev = tl.load(
360
+ VPOST_ptr + (k - 1) * num_cols + cols,
361
+ mask=mask, other=0.0,
362
+ ).to(tl.float32)
363
+ else:
364
+ v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
365
+
366
+ # Sigmoid surrogate gradient
367
+ x = v_post - vth * (1.0 - spike)
368
+ neg_ax = -ALPHA * x
369
+ neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)
370
+ sig = 1.0 / (1.0 + tl.exp(neg_ax))
371
+ sg = ALPHA * sig * (1.0 - sig)
372
+
373
+ total_gV = g_V + acc
374
+ grad_v_pre = g_s * sg + total_gV
375
+
376
+ tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
377
+
378
+ # Accumulate gradients for row parameters (reduction over K in registers)
379
+ acc_grad_beta += grad_v_pre * v_prev
380
+ acc_grad_vth += -g_s * sg - total_gV * spike
381
+
382
+ acc = grad_v_pre * beta
383
+
384
+ tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
385
+ tl.store(GRAD_BETA_ROW_ptr + cols, acc_grad_beta, mask=mask)
386
+ tl.store(GRAD_VTH_ROW_ptr + cols, acc_grad_vth, mask=mask)
387
+
388
+ class _TritonPLIFRowParamForward(torch.autograd.Function):
389
+ """Fused Triton PLIF with row-parameter beta/v_th.
390
+
391
+ For neurons with constant beta/v_th across K steps (ParametricLIFNode).
392
+ Eliminates expand+contiguous for beta/v_th tensors, reduces memory I/O by ~40%.
393
+ """
394
+
395
+ _BLOCK = 128
396
+
397
+ @staticmethod
398
+ def forward(ctx, beta_row, u, v_th_row, v_init, alpha):
399
+ beta_row_c = beta_row.contiguous()
400
+ u_c = u.contiguous()
401
+ v_th_row_c = v_th_row.contiguous()
402
+ v_init_c = v_init.contiguous()
403
+
404
+ K = u_c.shape[0]
405
+ num_cols = u_c[0].numel()
406
+
407
+ spike = torch.empty_like(u_c)
408
+ V_post = torch.empty_like(u_c)
409
+
410
+ BLOCK = _TritonPLIFRowParamForward._BLOCK
411
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
412
+
413
+ _fused_plif_fwd_rowparam_kernel[grid](
414
+ beta_row_c, u_c, v_th_row_c, v_init_c,
415
+ spike, V_post,
416
+ K, num_cols,
417
+ BLOCK=BLOCK,
418
+ )
419
+
420
+ if any(ctx.needs_input_grad[:4]):
421
+ ctx.save_for_backward(beta_row_c, v_th_row_c, v_init_c, V_post, spike)
422
+ ctx.K = K
423
+ ctx.num_cols = num_cols
424
+ ctx.alpha = alpha
425
+
426
+ return spike, V_post
427
+
428
+ @staticmethod
429
+ def backward(ctx, grad_spike, grad_V_post):
430
+ beta_row, v_th_row, v_init, V_post, spike = ctx.saved_tensors
431
+ K = ctx.K
432
+ num_cols = ctx.num_cols
433
+ alpha = ctx.alpha
434
+
435
+ if grad_spike is None:
436
+ grad_spike = torch.zeros_like(spike)
437
+ if grad_V_post is None:
438
+ grad_V_post = torch.zeros_like(V_post)
439
+
440
+ grad_spike_c = grad_spike.contiguous()
441
+ grad_V_post_c = grad_V_post.contiguous()
442
+
443
+ grad_beta_row = torch.empty_like(beta_row)
444
+ grad_u = torch.empty_like(V_post)
445
+ grad_v_th_row = torch.empty_like(v_th_row)
446
+ grad_v_init = torch.empty_like(v_init)
447
+
448
+ BLOCK = _TritonPLIFRowParamForward._BLOCK
449
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
450
+
451
+ _fused_plif_bwd_rowparam_kernel[grid](
452
+ beta_row, v_th_row, v_init, V_post, spike,
453
+ grad_spike_c, grad_V_post_c,
454
+ grad_beta_row, grad_u, grad_v_th_row, grad_v_init,
455
+ K, num_cols, float(alpha),
456
+ BLOCK=BLOCK,
457
+ )
458
+
459
+ return grad_beta_row, grad_u, grad_v_th_row, grad_v_init, None
460
+
461
+ class _TritonPLIFForward(torch.autograd.Function):
462
+ """Fused Triton PLIF forward + backward.
463
+
464
+ Single-pass sequential scan replaces the 3-phase approach:
465
+ Phase 1 (linear scan) + Phase 2 (spike iteration) + Phase 3 (correction)
466
+ → 1 fused kernel with inline spike detection + soft reset
467
+
468
+ Advantages:
469
+ - 1 kernel launch (vs 3-4 launches + ~10 element-wise ops)
470
+ - Exact computation (no iteration convergence issues)
471
+ - Less memory (no intermediate V_L, delta_S, delta_S_prev)
472
+ - Higher precision (fp32 accumulation, no bf16 intermediate store/load)
473
+ """
474
+
475
+ _BLOCK = 128
476
+
477
+ @staticmethod
478
+ def forward(ctx, beta, u, v_th, v_init, alpha):
479
+ beta_c = beta.contiguous()
480
+ u_c = u.contiguous()
481
+ v_th_c = v_th.contiguous()
482
+ v_init_c = v_init.contiguous()
483
+
484
+ K = beta_c.shape[0]
485
+ num_cols = beta_c[0].numel()
486
+
487
+ spike = torch.empty_like(u_c)
488
+ V_post = torch.empty_like(u_c)
489
+
490
+ BLOCK = _TritonPLIFForward._BLOCK
491
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
492
+
493
+ _fused_plif_fwd_kernel[grid](
494
+ beta_c, u_c, v_th_c, v_init_c,
495
+ spike, V_post,
496
+ K, num_cols,
497
+ BLOCK=BLOCK,
498
+ )
499
+
500
+ if any(ctx.needs_input_grad[:4]):
501
+ ctx.save_for_backward(beta_c, v_th_c, v_init_c, V_post, spike)
502
+ ctx.K = K
503
+ ctx.num_cols = num_cols
504
+ ctx.alpha = alpha
505
+
506
+ return spike, V_post
507
+
508
+ @staticmethod
509
+ def backward(ctx, grad_spike, grad_V_post):
510
+ beta, v_th, v_init, V_post, spike = ctx.saved_tensors
511
+ K = ctx.K
512
+ num_cols = ctx.num_cols
513
+ alpha = ctx.alpha
514
+
515
+ if grad_spike is None:
516
+ grad_spike = torch.zeros_like(spike)
517
+ if grad_V_post is None:
518
+ grad_V_post = torch.zeros_like(V_post)
519
+
520
+ grad_spike_c = grad_spike.contiguous()
521
+ grad_V_post_c = grad_V_post.contiguous()
522
+
523
+ grad_beta = torch.empty_like(beta)
524
+ grad_u = torch.empty_like(beta)
525
+ grad_v_th = torch.empty_like(v_th)
526
+ grad_v_init = torch.empty_like(v_init)
527
+
528
+ BLOCK = _TritonPLIFForward._BLOCK
529
+ grid = ((num_cols + BLOCK - 1) // BLOCK,)
530
+
531
+ _fused_plif_bwd_kernel[grid](
532
+ beta, v_th, v_init, V_post, spike,
533
+ grad_spike_c, grad_V_post_c,
534
+ grad_beta, grad_u, grad_v_th, grad_v_init,
535
+ K, num_cols, float(alpha),
536
+ BLOCK=BLOCK,
537
+ )
538
+
539
+ return grad_beta, grad_u, grad_v_th, grad_v_init, None
540
+
541
+
542
+ # ============================================================
543
+ # Hillis-Steele parallel prefix scan (CPU fallback)
544
+ # ============================================================
545
+
546
+ def hillis_steele_scan(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
547
+ """
548
+ Hillis-Steele 并行前缀扫描:计算仿射映射序列的所有前缀复合。
549
+
550
+ 给定仿射映射 f_k(x) = a[k] * x + b[k], k = 0, ..., K-1,
551
+ 计算前缀复合 F_k = f_k ∘ f_{k-1} ∘ ... ∘ f_0,
552
+ 使得 V[k] = F_k(v_init) = A[k] * v_init + B[k]。
553
+
554
+ 复合规则: (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2)
555
+
556
+ 实现使用 torch.cat 重建张量(无原地操作),完全兼容 autograd。
557
+
558
+ Args:
559
+ a: (K, *shape) — 乘性系数(如 β)
560
+ b: (K, *shape) — 加性项(如 α·I)
561
+
562
+ Returns:
563
+ A: (K, *shape) — 前缀积 A[k] = ∏_{j=0}^{k} a[j]
564
+ B: (K, *shape) — 前缀和 B[k] 使得 V[k] = A[k] * v_init + B[k]
565
+
566
+ 并行深度: O(log K)
567
+ 工作量: O(K * log K)
568
+ """
569
+ K = a.shape[0]
570
+ A = a
571
+ B = b
572
+
573
+ d = 1
574
+ while d < K:
575
+ A_new_tail = A[d:] * A[:-d]
576
+ B_new_tail = A[d:] * B[:-d] + B[d:]
577
+
578
+ A = torch.cat([A[:d], A_new_tail], dim=0)
579
+ B = torch.cat([B[:d], B_new_tail], dim=0)
580
+
581
+ d *= 2
582
+
583
+ return A, B
584
+
585
+
586
+ # ============================================================
587
+ # Public API: linear_recurrence
588
+ # ============================================================
589
+
590
+ def linear_recurrence(beta: torch.Tensor, u: torch.Tensor, v_init: torch.Tensor) -> torch.Tensor:
591
+ """
592
+ 求解线性递推: V[k] = beta[k] * V[k-1] + u[k], V[-1] = v_init
593
+
594
+ CUDA 后端: Triton fused kernel(1 次 kernel launch,O(K) 工作量)
595
+ CPU 后端: Hillis-Steele parallel scan(O(K log K) 工作量)
596
+
597
+ Args:
598
+ beta: (K, *shape) — 衰减系数,值域 (0, 1)
599
+ u: (K, *shape) — 输入项
600
+ v_init: (*shape) — 初始状态
601
+
602
+ Returns:
603
+ V: (K, *shape) — 所有 K 步的状态
604
+ """
605
+ if _HAS_TRITON and beta.is_cuda:
606
+ return _TritonLinearRecurrence.apply(beta, u, v_init)
607
+ # CPU fallback
608
+ A, B = hillis_steele_scan(beta, u)
609
+ V = A * v_init.unsqueeze(0) + B
610
+ return V
611
+
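The CPU fallback's scan can be illustrated with a list-based pure-Python sketch (`hillis_steele` is an illustrative name) and checked against the sequential recurrence it replaces:

```python
def hillis_steele(a, b):
    # prefix-compose affine maps f_k(x) = a[k]*x + b[k];
    # combine rule: (a2, b2) ∘ (a1, b1) = (a2*a1, a2*b1 + b2)
    A, B = list(a), list(b)
    d = 1
    while d < len(A):
        # element k combines with element k-d, using the OLD values of A/B
        A_new = A[:d] + [A[k] * A[k - d] for k in range(d, len(A))]
        B_new = B[:d] + [A[k] * B[k - d] + B[k] for k in range(d, len(B))]
        A, B = A_new, B_new
        d *= 2
    return A, B

a = [0.9, 0.8, 0.7, 0.6]
b = [0.1, 0.2, 0.3, 0.4]
v_init = 1.0
A, B = hillis_steele(a, b)
parallel = [A[k] * v_init + B[k] for k in range(4)]

# sequential ground truth: V[k] = a[k] * V[k-1] + b[k]
v, seq = v_init, []
for ak, bk in zip(a, b):
    v = ak * v + bk
    seq.append(v)

assert all(abs(p - s) < 1e-9 for p, s in zip(parallel, seq))
```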
612
+
613
+ # ============================================================
614
+ # PLIF parallel forward (with spike iteration)
615
+ # ============================================================
616
+
617
+ def plif_parallel_forward(
618
+ beta: torch.Tensor,
619
+ u: torch.Tensor,
620
+ v_th: torch.Tensor,
621
+ v_init: torch.Tensor,
622
+ max_iter: int = 3,
623
+ surrogate_function=None,
624
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
625
+ """
626
+ PLIF 神经元的并行前向传播(soft reset,surrogate gradient 兼容)。
627
+
628
+ 求解:
629
+ V_pre[k] = beta[k] * V_post[k-1] + u[k]
630
+ s[k] = Θ(V_pre[k] - v_th[k])
631
+ V_post[k] = V_pre[k] - v_th[k] * s[k]
632
+
633
+ 方法:
634
+ Phase 1: 线性轨迹 parallel scan(有梯度)
635
+ Phase 2: spike 不动点迭代(detach,确定离散 spike pattern)
636
+ Phase 3: 用 converged spike pattern 重算 V_post(有梯度),
637
+ surrogate_function(V_pre - v_th) 生成可微 spike 输出
638
+
639
+ Args:
640
+ beta: (K, *shape) — 衰减系数
641
+ u: (K, *shape) — 输入 α·I
642
+ v_th: (K, *shape) — 动态阈值
643
+ v_init: (*shape) — 初始膜电位
644
+ max_iter: spike 不动点迭代次数上限
645
+ surrogate_function: SpikingJelly surrogate gradient 函数(如 surrogate.Sigmoid(alpha=4.0))
646
+ None 时退化为硬阈值(无梯度)
647
+
648
+ Returns:
649
+ spike: (K, *shape) — spike 模式(有 surrogate gradient)
650
+ V_post: (K, *shape) — 发放后膜电位
651
+ V_pre: (K, *shape) — 发放前膜电位(fused path 返回 None)
652
+ """
653
+ # Fused Triton path: single-pass sequential scan (exact, no iteration)
654
+ # Replaces 3-phase approach with 1 kernel launch — ~3x faster forward, ~5x faster backward
655
+ if (_HAS_TRITON and beta.is_cuda and surrogate_function is not None
656
+ and hasattr(surrogate_function, 'alpha')
657
+ and type(surrogate_function).__name__ == 'Sigmoid'):
658
+ alpha = float(surrogate_function.alpha)
659
+ spike, V_post = _TritonPLIFForward.apply(beta, u, v_th, v_init, alpha)
660
+ return spike, V_post, None
661
+
662
+ # Fallback: 3-phase approach (CPU, non-Sigmoid surrogates, or no surrogate)
663
+ # Phase 1: linear trajectory V_L (assuming the neuron never fires)
664
+ V_L = linear_recurrence(beta, u, v_init) # (K, *shape)
665
+
666
+ # Phase 2: spike fixed-point iteration (fully detached; no gradient graph)
+ # Purpose: determine which neurons fire at which steps (discrete decision)
668
+ with torch.no_grad():
669
+ V_L_det = V_L.detach()
670
+ beta_det = beta.detach()
671
+ v_th_det = v_th.detach()
672
+ v_init_det = v_init.detach() if isinstance(v_init, torch.Tensor) else v_init
673
+
674
+ spike_pattern = (V_L_det >= v_th_det).float()
675
+
676
+ for _ in range(max_iter - 1):
677
+ # Compute ΔS: ΔS[k] = beta[k] * ΔS[k-1] + v_th[k] * s[k]
678
+ delta_S = linear_recurrence(
679
+ beta_det, v_th_det * spike_pattern,
680
+ torch.zeros_like(v_init_det) if isinstance(v_init_det, torch.Tensor)
681
+ else torch.zeros_like(V_L_det[0]),
682
+ )
683
+
684
+ # ΔS_prev = ΔS[k-1] (shifted by one step)
685
+ delta_S_prev = torch.zeros_like(delta_S)
686
+ delta_S_prev[1:] = delta_S[:-1]
687
+
688
+ # V_pre = V_L - beta * ΔS_prev
689
+ V_pre_det = V_L_det - beta_det * delta_S_prev
690
+
691
+ # Update spike pattern
692
+ spike_new = (V_pre_det >= v_th_det).float()
693
+
694
+ # Convergence check
695
+ if torch.equal(spike_new, spike_pattern):
696
+ break
697
+ spike_pattern = spike_new
698
+
699
+ # Phase 3: recompute V_post with the converged spike pattern (full gradients)
+ # spike_pattern is detached and enters the computation as a constant;
+ # gradients flow through u, v_th, beta, v_init
702
+ u_eff = u - v_th * spike_pattern
703
+ V_post = linear_recurrence(beta, u_eff, v_init) # (K, *shape)
704
+
705
+ # Rebuild V_pre (with gradients, used for the surrogate gradient)
706
+ V_post_prev = torch.zeros_like(V_post)
707
+ if isinstance(v_init, torch.Tensor):
708
+ V_post_prev[0] = v_init
709
+ V_post_prev[1:] = V_post[:-1]
710
+ V_pre = beta * V_post_prev + u
711
+
712
+ # Produce the differentiable spike output
713
+ if surrogate_function is not None:
714
+ # forward: Heaviside(V_pre - v_th), backward: surrogate gradient
715
+ spike = surrogate_function(V_pre - v_th)
716
+ else:
717
+ # 退化模式:硬阈值,无梯度
718
+ spike = (V_pre >= v_th).float()
719
+
720
+ return spike, V_post, V_pre
721
+
722
+
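For intuition, the sequential ground truth that the fused kernel computes (and that the 3-phase fallback must converge to) fits in a few lines of plain Python (`plif_scan` is an illustrative name):

```python
def plif_scan(beta, u, v_th, v_init):
    # sequential reference: charge, fire (hard threshold), soft reset
    spikes, v_post, v = [], [], v_init
    for bk, uk, tk in zip(beta, u, v_th):
        v_pre = bk * v + uk              # charge: V_pre = beta*V + u
        s = 1.0 if v_pre >= tk else 0.0  # fire:   s = Θ(V_pre - v_th)
        v = v_pre - tk * s               # soft reset: subtract the threshold
        spikes.append(s)
        v_post.append(v)
    return spikes, v_post

# constant input below threshold: charge accumulates, the second step fires
beta = [0.9, 0.9, 0.9]
u    = [0.6, 0.6, 0.6]
v_th = [1.0, 1.0, 1.0]
spikes, v_post = plif_scan(beta, u, v_th, 0.0)
```

Because the reset subtracts v_th instead of zeroing the membrane, residual charge (here 0.14 after the spike) carries into the next step.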
723
+ def plif_rowparam_forward(
724
+ beta_row: torch.Tensor,
725
+ u: torch.Tensor,
726
+ v_th_row: torch.Tensor,
727
+ v_init: torch.Tensor,
728
+ surrogate_function=None,
729
+ ) -> tuple[torch.Tensor, torch.Tensor]:
730
+ """
731
+ 行参数 PLIF 前向:beta 和 v_th 在 K 步中保持恒定。
732
+
733
+ 比 plif_parallel_forward 快 ~40%(省去 expand+contiguous,减少 2/3 显存读取)。
734
+ 用于 ParametricLIFNode(固定 beta/v_th)或合并多个固定参数神经元。
735
+
736
+ Args:
737
+ beta_row: (*shape) — 每列的衰减率(所有 K 步相同)
738
+ u: (K, *shape) — 每步输入
739
+ v_th_row: (*shape) — 每列的阈值(所有 K 步相同)
740
+ v_init: (*shape) — 初始膜电位
741
+ surrogate_function: surrogate gradient 函数
742
+
743
+ Returns:
744
+ spike: (K, *shape) — spike 模式
745
+ V_post: (K, *shape) — 发放后膜电位
746
+ """
747
+ if (_HAS_TRITON and u.is_cuda and surrogate_function is not None
748
+ and hasattr(surrogate_function, 'alpha')
749
+ and type(surrogate_function).__name__ == 'Sigmoid'):
750
+ alpha = float(surrogate_function.alpha)
751
+ spike, V_post = _TritonPLIFRowParamForward.apply(
752
+ beta_row, u, v_th_row, v_init, alpha,
753
+ )
754
+ return spike, V_post
755
+
756
+ # Fallback: expand to full (K, *shape) and use standard path
757
+ K = u.shape[0]
758
+ beta = beta_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
759
+ v_th = v_th_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
760
+ spike, V_post, _ = plif_parallel_forward(beta, u, v_th, v_init, surrogate_function=surrogate_function)
761
+ return spike, V_post
762
+
763
+
764
+ def plif_fixed_param_forward(
765
+ beta,
766
+ u: torch.Tensor,
767
+ v_th,
768
+ v_init: torch.Tensor,
769
+ max_iter: int = 3,
770
+ surrogate_function=None,
771
+ ) -> tuple[torch.Tensor, torch.Tensor]:
772
+ """
773
+ 固定参数 PLIF 神经元的并行前向(如输出神经元、FFN 神经元)。
774
+
775
+ ParametricLIFNode 方程: V[k] = beta * V[k-1] + (1-beta) * x[k]
776
+ 其中 beta = 1/(1+exp(w)), 可为 scalar tensor(保持梯度流向 w)。
777
+
778
+ scalar/0-dim beta 和 v_th 使用 row-param 内核(无需 expand 到 (K, *shape))。
779
+
780
+ Args:
781
+ beta: 衰减率 — scalar float、0-dim tensor 或 (K, *shape) tensor
782
+ u: (K, *shape) — 输入(已乘以 (1-beta))
783
+ v_th: 阈值 — scalar float、0-dim tensor 或 (K, *shape) tensor
784
+ v_init: (*shape) — 初始膜电位
785
+ max_iter: spike 迭代次数
786
+ surrogate_function: surrogate gradient 函数
787
+
788
+ Returns:
789
+ spike: (K, *shape) — spike 模式
790
+ V_post: (K, *shape) — 发放后膜电位
791
+ """
792
+ K = u.shape[0]
793
+ shape = u.shape[1:]
794
+
795
+ # Row-param fast path: beta and v_th are both scalar/0-dim → expand to (*shape) row vectors
796
+ beta_is_scalar = isinstance(beta, torch.Tensor) and beta.dim() == 0
797
+ beta_is_float = not isinstance(beta, torch.Tensor)
798
+ vth_is_scalar = isinstance(v_th, torch.Tensor) and v_th.dim() == 0
799
+ vth_is_float = not isinstance(v_th, torch.Tensor)
800
+
801
+ if (beta_is_scalar or beta_is_float) and (vth_is_scalar or vth_is_float):
802
+ # Build row vectors (*shape)
803
+ if beta_is_scalar:
804
+ beta_row = beta.expand(*shape).contiguous()
805
+ else:
806
+ beta_row = torch.full(shape, beta, device=u.device, dtype=u.dtype)
807
+ if vth_is_scalar:
808
+ v_th_row = v_th.expand(*shape).contiguous()
809
+ else:
810
+ v_th_row = torch.full(shape, v_th, device=u.device, dtype=u.dtype)
811
+ return plif_rowparam_forward(beta_row, u, v_th_row, v_init, surrogate_function)
812
+
813
+ # Full-tensor path: expand to (K, *shape) if needed
814
+ if isinstance(beta, torch.Tensor):
815
+ if beta.dim() == 0:
816
+ beta = beta.expand(K, *shape).contiguous()
817
+ else:
818
+ beta = torch.full_like(u, beta)
819
+
820
+ if isinstance(v_th, torch.Tensor):
821
+ if v_th.dim() == 0:
822
+ v_th = v_th.expand(K, *shape).contiguous()
823
+ else:
824
+ v_th = torch.full_like(u, v_th)
825
+
826
+ spike, V_post, _ = plif_parallel_forward(
827
+ beta, u, v_th, v_init, max_iter, surrogate_function,
828
+ )
829
+ return spike, V_post
atomic_ops/plif_node.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ PLIFNode: D 维固定参数 PLIF 神经元(设计文档 5.5 "普通 SNN 神经元")
3
+
4
+ 与 SelectivePLIFNode 的区别:
5
+ SelectivePLIF: β(t), α(t), V_th(t) 由输入每步动态计算(选择性记忆)
6
+ PLIFNode: β, V_th 为 D 维可学习参数,训练后固定(信号转换)
7
+
8
+ 每个维度有独立的可学习参数:
9
+ β_d = sigmoid(w_d): 时间常数(衰减率)
10
+ V_th_d: 发放阈值
11
+
12
+ 动力学(与 ParametricLIF 一致):
13
+ V[t] = β · V[t-1] + (1-β) · x[t]
14
+ s[t] = Θ(V[t] - V_th) (surrogate gradient)
15
+ V[t] -= V_th · s[t] (soft reset)
16
+ """
17
+
18
+ import math
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from spikingjelly.activation_based import base, surrogate
23
+
24
+
25
+ class PLIFNode(base.MemoryModule):
26
+ """
27
+ D 维固定参数 PLIF 神经元。
28
+
29
+ Args:
30
+ dim: 神经元数量(每个维度独立参数)
31
+ init_tau: 初始时间常数 τ(β = 1 - 1/τ)
32
+ v_threshold: 初始发放阈值
33
+ surrogate_function: surrogate gradient 函数
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ dim: int,
39
+ init_tau: float = 2.0,
40
+ v_threshold: float = 0.5,
41
+ surrogate_function=surrogate.Sigmoid(alpha=4.0),
42
+ ):
43
+ super().__init__()
44
+ # D-dimensional learnable parameters (random init, independent per dimension)
+ # w: controls β = sigmoid(w); randomization yields diverse time constants
+ # init_w ± 0.5 → β ∈ ~[sigmoid(init_w-0.5), sigmoid(init_w+0.5)]
+ # at tau=2.0, init_w=0 and β ∈ ~[0.38, 0.62]
48
+ init_w = -math.log(init_tau - 1.0)
49
+ self.w = nn.Parameter(torch.empty(dim).normal_(init_w, 0.5))
50
+ # v_th: firing threshold; U[0.5x, 1.5x] uniform init yields cross-dimension diversity
51
+ self.v_th = nn.Parameter(torch.empty(dim).uniform_(
52
+ v_threshold * 0.5, v_threshold * 1.5,
53
+ ))
54
+ self.surrogate_function = surrogate_function
55
+ # Membrane potential state (reset to 0. by functional.reset_net)
56
+ self.register_memory('v', 0.)
57
+
58
+ @property
59
+ def beta(self):
60
+ """D 维衰减率 β = sigmoid(w),值域 (0, 1)。"""
61
+ return torch.sigmoid(self.w)
62
+
63
+ def forward(self, x):
64
+ """
65
+ 单步前向传播。
66
+
67
+ V[t] = β · V[t-1] + (1-β) · x[t], spike = Θ(V-V_th), soft reset。
68
+
69
+ Args:
70
+ x: 输入电流, shape (batch, dim)
71
+
72
+ Returns:
73
+ spike: 二值脉冲, shape (batch, dim), 值域 {0, 1}
74
+ """
75
+ if isinstance(self.v, float):
76
+ self.v = torch.zeros_like(x)
77
+ beta = self.beta
78
+ self.v = beta * self.v + (1.0 - beta) * x
79
+ spike = self.surrogate_function(self.v - self.v_th)
80
+ self.v = self.v - spike * self.v_th # soft reset
81
+ return spike
atomic_ops/rms_norm.py ADDED
@@ -0,0 +1,36 @@
1
+ """
2
+ RMSNorm: 残差流分支归一化(Pre-LN 模式)。
3
+
4
+ 位置: h → RMSNorm → PLIFNode → SNN子层 → out_proj → 残差
5
+ 作用: 控制送入 PLIFNode 的输入 scale,防止残差流漂移/爆炸。
6
+ 仅归一化分支输入,残差流本身不被归一化。
7
+
8
+ 对标 Qwen3/LLaMA 的 Pre-LN RMSNorm。
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ class RMSNorm(nn.Module):
16
+ """Root Mean Square Layer Normalization.
17
+
18
+ x_norm = x / RMS(x) * weight
19
+ RMS(x) = sqrt(mean(x^2) + eps)
20
+
21
+ Args:
22
+ dim: normalized dimension
+ eps: numerical-stability epsilon
24
+ """
25
+
26
+ def __init__(self, dim: int, eps: float = 1e-6):
27
+ super().__init__()
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+ self.eps = eps
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ input_dtype = x.dtype
33
+ x = x.float()
34
+ variance = x.pow(2).mean(-1, keepdim=True)
35
+ x = x * torch.rsqrt(variance + self.eps)
36
+ return (self.weight * x).to(input_dtype)
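The normalization above can be mirrored with plain floats to see the invariant it enforces (a sketch; `rms_norm` here is an illustrative standalone helper, not the module's class):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # x / sqrt(mean(x^2) + eps) * weight, mirroring RMSNorm.forward
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]

y = rms_norm([3.0, 4.0], [1.0, 1.0])
# mean(x^2) = 12.5, RMS ≈ 3.5355 → y ≈ [0.8485, 1.1314]
# with unit weights the output has mean square ≈ 1 regardless of input scale
assert abs(sum(v * v for v in y) / len(y) - 1.0) < 1e-5
```

Unlike LayerNorm, no mean is subtracted; only the root-mean-square scale is divided out, which is why it is a good fit for controlling the input scale of a spiking branch.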
atomic_ops/selective_plif.py ADDED
@@ -0,0 +1,94 @@
1
+ """
2
+ SelectivePLIFNode: 动态参数的 PLIF 神经元
3
+
4
+ 与标准 ParametricLIFNode 的区别:
5
+ - β(t), α(t), V_th(t) 作为外部参数每步传入,不是内部 nn.Parameter
6
+ - 本神经元无可训练参数,所有学习发生在 SNNBlock 的调制网络中
7
+ - 仅支持 step_mode='s'(单步模式)
8
+ - 仅支持 soft reset(v_reset=None)
9
+
10
+ 状态方程:
11
+ V[t] = β(t) · V[t-1] + α(t) · I[t]
12
+ s[t] = Θ(V[t] - V_th(t)) (surrogate gradient)
13
+ V[t] -= V_th(t) · s[t] (soft reset)
14
+ """
15
+
16
+ import torch
17
+ from spikingjelly.activation_based import neuron, surrogate
18
+
19
+
20
+ class SelectivePLIFNode(neuron.BaseNode):
21
+ """
22
+ 隐状态空间的核心神经元。
23
+
24
+ 接收外部动态计算的 β(t), α(t), V_th(t),执行:
25
+ charge → fire → soft reset
26
+
27
+ Args:
28
+ surrogate_function: surrogate gradient 函数,默认 Sigmoid(alpha=4.0)
29
+ detach_reset: 是否在 reset 时 detach spike,默认 False
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ surrogate_function=surrogate.Sigmoid(alpha=4.0),
35
+ detach_reset: bool = False,
36
+ ):
37
+ # v_threshold=1.0 is a placeholder; the externally supplied v_th is used instead
+ # v_reset=None enables soft-reset mode, register_memory('v', 0.)
39
+ super().__init__(
40
+ v_threshold=1.0,
41
+ v_reset=None,
42
+ surrogate_function=surrogate_function,
43
+ detach_reset=detach_reset,
44
+ step_mode='s',
45
+ backend='torch',
46
+ store_v_seq=False,
47
+ )
48
+
49
+ def single_step_forward(
50
+ self,
51
+ x: torch.Tensor,
52
+ beta: torch.Tensor,
53
+ alpha: torch.Tensor,
54
+ v_th: torch.Tensor,
55
+ ) -> torch.Tensor:
56
+ """
57
+ 单步前向传播。
58
+
59
+ Args:
60
+ x: 输入电流 I[t], shape (batch, D*N)
61
+ beta: 衰减率 β(t), shape (batch, D*N), 值域 (0, 1)
62
+ alpha: 写入增益 α(t), shape (batch, D*N), 值域 R+
63
+ v_th: 动态阈值 V_th(t), shape (batch, D*N), 值域 R+
64
+
65
+ Returns:
66
+ spike: 二值脉冲 s[t], shape (batch, D*N), 值域 {0, 1}
67
+ """
68
+ # Phase 0: on the first step, expand v from a float to a tensor shaped like the input
69
+ self.v_float_to_tensor(x)
70
+
71
+ # Phase 1: Charge — membrane potential update
+ # V[t] = β(t) · V[t-1] + α(t) · I[t]
73
+ self.v = beta * self.v + alpha * x
74
+
75
+ # Phase 2: Fire — use the dynamic v_th (not self.v_threshold)
+ # spike = Heaviside(V[t] - V_th(t)); backward uses the surrogate gradient
77
+ spike = self.surrogate_function(self.v - v_th)
78
+
79
+ # Phase 3: Soft Reset — V[t] -= V_th(t) · s[t]
80
+ if self.detach_reset:
81
+ spike_d = spike.detach()
82
+ else:
83
+ spike_d = spike
84
+ self.v = self.v - spike_d * v_th
85
+
86
+ return spike
87
+
88
+ def extra_repr(self) -> str:
89
+ return (
90
+ f'v_reset={self.v_reset}, '
91
+ f'detach_reset={self.detach_reset}, '
92
+ f'step_mode={self.step_mode}, '
93
+ f'surrogate={self.surrogate_function}'
94
+ )
atomic_ops/snn_block.py ADDED
@@ -0,0 +1,242 @@
+ """
+ SNNBlock: a full SNN hidden-state-space block (parallelized version).
+
+ Structure (per SNN time step):
+     spike_in {0,1}^D
+       ├─→ W_in → I[t] ∈ R^{D*N}
+       ├─→ W_β^(x) + b_β → σ → β(t)
+       ├─→ W_α^(x) + b_α → softplus → α(t)
+       ├─→ W_th^(x) + b_th → |·|+V_min → V_th(t)
+       ├─→ W_gate → sigmoid → gate ∈ (0,1)^D
+       └─→ W_skip → I_skip ∈ R^D
+
+     SelectivePLIF(I, β, α, V_th) → s[t] ∈ {0,1}^{D*N}
+
+     W_out · V_post[t] ⊙ gate + I_skip → continuous output ∈ R^D
+
+ See SNN_SELECTIVE_STATE_SPACE.md for the mathematical derivation.
+ """
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from spikingjelly.activation_based import base, layer, surrogate
+
+ from .selective_plif import SelectivePLIFNode
+ from .parallel_scan import plif_parallel_forward
+
+
+ # ====== Fused modulation activations (torch.compile) ======
+ # Fuse sigmoid + softplus + abs + alpha*I into a single kernel.
+ # 7-8 separate element-wise kernels → 1 fused kernel, ~4x speedup on DN-sized tensors.
+ # The first call triggers JIT compilation (~seconds); later calls hit the cache.
+
+ @torch.compile(backend='inductor', fullgraph=True)
+ def _fused_modulation(raw_beta, b_beta, raw_alpha, b_alpha, raw_th, b_th, v_th_min, I_all):
+     beta = torch.sigmoid(raw_beta + b_beta)
+     alpha = F.softplus(raw_alpha + b_alpha)
+     v_th = v_th_min + torch.abs(raw_th + b_th)
+     u = alpha * I_all
+     return beta, u, v_th
+
+
+ class SNNBlock(base.MemoryModule):
+     """
+     A single SNN block (parallelized).
+
+     Args:
+         D: visible dimension (the dimension blocks communicate over)
+         N: state expansion factor (hidden neurons per channel)
+         v_th_min: lower bound of the dynamic threshold
+         surrogate_function: surrogate gradient function
+     """
+
+     def __init__(
+         self,
+         D: int,
+         N: int = 8,
+         v_th_min: float = 0.1,
+         surrogate_function=surrogate.Sigmoid(alpha=4.0),
+     ):
+         super().__init__()
+         self.D = D
+         self.N = N
+         self.v_th_min = v_th_min
+         DN = D * N
+
+         # ====== Six parallel input projections (SNN synapses: spike input) ======
+         self.W_in = layer.Linear(D, DN, bias=False, step_mode='s')
+         self.W_beta_x = layer.Linear(D, DN, bias=False, step_mode='s')
+         self.W_alpha_x = layer.Linear(D, DN, bias=False, step_mode='s')
+         self.W_th_x = layer.Linear(D, DN, bias=False, step_mode='s')
+         self.W_gate = layer.Linear(D, D, bias=False, step_mode='s')
+         self.W_skip = layer.Linear(D, D, bias=False, step_mode='s')
+
+         # ====== β/α/V_th depend on spike_in only (no W^(V)·V term) ======
+
+         # ====== Modulation biases (structured initialization) ======
+         self.b_beta = nn.Parameter(torch.empty(DN))
+         self.b_alpha = nn.Parameter(torch.empty(DN))
+         self.b_th = nn.Parameter(torch.empty(DN))
+
+         # ====== Output projection: D*N → D (SNN synapse) ======
+         self.W_out = layer.Linear(DN, D, bias=False, step_mode='s')
+
+         # ====== Hidden state-space neurons (D*N of them, dynamic parameters) ======
+         self.hidden_neuron = SelectivePLIFNode(
+             surrogate_function=surrogate_function,
+             detach_reset=False,
+         )
+
+         # ====== Parameter initialization ======
+         self._initialize_parameters()
+
+     def _initialize_parameters(self):
+         """Function-guided initialization."""
+         D, N = self.D, self.N
+         K_ref = 16
+
+         # Target β distribution: multiple timescales in [0.80, 0.99]
+         beta_values = torch.linspace(0.80, 0.99, N)
+
+         # ====== 1. β bias: logit-spaced + random perturbation across dimensions ======
+         b_beta_per_n = torch.log(beta_values / (1.0 - beta_values))
+         # Use the per-n values as the mean; add N(0, 0.1) noise to break symmetry across the D channels
+         self.b_beta.data.copy_(b_beta_per_n.repeat(D))
+         self.b_beta.data.add_(torch.empty_like(self.b_beta).normal_(0, 0.1))
+
+         # ====== 2. α bias: softplus(0.5413) ≈ 1.0 + random perturbation across dimensions ======
+         # Mean 0.5413 with N(0, 0.1) noise → α ∈ ~[0.7, 1.3]
+         self.b_alpha.data.normal_(0.5413, 0.1)
+
+         # ====== 3. W^(x) weights ======
+         for lin in [self.W_in, self.W_gate, self.W_skip, self.W_out]:
+             nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
+         for lin in [self.W_beta_x, self.W_alpha_x, self.W_th_x]:
+             nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
+             lin.weight.data.mul_(0.1)
+
+         # ====== 4. W_in timescale scaling ======
+         scale_per_n = torch.sqrt(1.0 - beta_values ** 2)  # (N,)
+         scale_DN = scale_per_n.repeat(D)  # (D*N,)
+         with torch.no_grad():
+             self.W_in.weight.mul_(scale_DN.unsqueeze(1))
+
+         # ====== 5. b_th: σ_V calibration ======
+         # σ_V = sqrt(p/3) * sqrt(1 - β^{2K})
+         # where p is the input firing rate. The old version assumed p=0.5 (σ_I=0.408),
+         # but the actual input_neuron firing rate is roughly 0.07~0.45, lower in deeper layers.
+         # Use p=0.15 as a conservative estimate so v_th is not so high that neurons go dead.
+         p_assumed = 0.15
+         sigma_I_base = math.sqrt(p_assumed / 3.0)
+         sigma_V_per_n = sigma_I_base * torch.sqrt(
+             1.0 - beta_values ** (2 * K_ref)
+         )
+         target_p_fire = torch.linspace(0.25, 0.08, N)
+         z_scores = math.sqrt(2.0) * torch.erfinv(
+             2.0 * (1.0 - target_p_fire) - 1.0
+         )
+         target_V_th = sigma_V_per_n * z_scores
+         b_th_per_n = torch.clamp(target_V_th - self.v_th_min, min=0.05)
+         # Use the per-n values as the mean; add N(0, 0.02) noise to break symmetry across the D channels
+         self.b_th.data.copy_(b_th_per_n.repeat(D))
+         self.b_th.data.add_(torch.empty_like(self.b_th).normal_(0, 0.02))
+
+         # ====== 6. W_out firing-rate-balancing scaling ======
+         out_scale_per_n = 1.0 / torch.sqrt(target_p_fire)
+         out_scale_per_n = out_scale_per_n / out_scale_per_n.mean()
+         out_scale_DN = out_scale_per_n.repeat(D)
+         with torch.no_grad():
+             self.W_out.weight.mul_(out_scale_DN.unsqueeze(0))
+
+     def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
+         """
+         Parallel forward pass: process the whole sequence with a parallel scan.
+
+         Args:
+             spike_in_seq: (TK, batch, D) — input spikes for all T×K frames
+
+         Returns:
+             continuous_out: (TK, batch, D) — continuous output for all T×K frames
+                 (V_post projected through W_out)
+         """
+         TK, batch, D = spike_in_seq.shape
+         DN = self.D * self.N
+
+         # ====== Phase 1: batched projections (all TK frames at once) ======
+         flat = spike_in_seq.reshape(TK * batch, D)
+
+         I_all = F.linear(flat, self.W_in.weight).reshape(TK, batch, DN)
+         raw_beta = F.linear(flat, self.W_beta_x.weight).reshape(TK, batch, DN)
+         raw_alpha = F.linear(flat, self.W_alpha_x.weight).reshape(TK, batch, DN)
+         raw_th = F.linear(flat, self.W_th_x.weight).reshape(TK, batch, DN)
+         gate_all = torch.sigmoid(
+             F.linear(flat, self.W_gate.weight).reshape(TK, batch, D)
+         )
+         I_skip_all = F.linear(flat, self.W_skip.weight).reshape(TK, batch, D)
+
+         # ====== Phase 1b: fused activations (torch.compile → single kernel) ======
+         beta_all, u_hidden, v_th_all = _fused_modulation(
+             raw_beta, self.b_beta, raw_alpha, self.b_alpha,
+             raw_th, self.b_th, self.v_th_min, I_all,
+         )
+
+         # Fetch the hidden neurons' initial state
+         v_init_hidden = self.hidden_neuron.v
+         if isinstance(v_init_hidden, float):
+             v_init_hidden = torch.zeros(batch, DN, device=flat.device, dtype=flat.dtype)
+
+         s_hidden, V_post_hidden, _ = plif_parallel_forward(
+             beta_all, u_hidden, v_th_all, v_init_hidden, max_iter=3,
+             surrogate_function=self.hidden_neuron.surrogate_function,
+         )
+
+         # Update the hidden neuron state (keep the last step for the next call)
+         self.hidden_neuron.v = V_post_hidden[-1].detach()
+
+         # ====== Phase 4: output projection (V_post → W_out: continuous gradient straight through to β) ======
+         # Feed V_post (membrane potential) instead of spikes into W_out to remove the
+         # surrogate-gradient bottleneck:
+         #   spike path:  ∂spike/∂β = surrogate'(V-v_th) · V_prev ≈ 0 (most of the time)
+         #   V_post path: ∂V_post/∂β = V_prev (no surrogate blocking; gradient at every step)
+         v_flat = V_post_hidden.reshape(TK * batch, DN)
+         I_out_all = F.linear(v_flat, self.W_out.weight).reshape(TK, batch, D)
+         I_total_all = I_out_all * gate_all + I_skip_all  # (TK, batch, D)
+
+         # output_neuron removed: continuous values are handled by layer-level K-frame aggregation
+         return I_total_all  # (TK, batch, D), continuous values
+
+     def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
+         """
+         Single-step forward pass (for debugging/compatibility).
+
+         Args:
+             spike_in: binary spike input, shape (batch, D), values in {0, 1}
+
+         Returns:
+             continuous_out: continuous output, shape (batch, D)
+         """
+         V_prev = self.hidden_neuron.v
+         if isinstance(V_prev, float):
+             V_prev = torch.zeros(
+                 spike_in.shape[0], self.D * self.N,
+                 device=spike_in.device, dtype=spike_in.dtype,
+             )
+
+         I_t = self.W_in(spike_in)
+
+         # β modulation depends on spike_in only
+         beta = torch.sigmoid(self.W_beta_x(spike_in) + self.b_beta)
+         alpha = F.softplus(self.W_alpha_x(spike_in) + self.b_alpha)
+         v_th = self.v_th_min + torch.abs(self.W_th_x(spike_in) + self.b_th)
+
+         gate = torch.sigmoid(self.W_gate(spike_in))
+         I_skip = self.W_skip(spike_in)
+
+         s_hidden = self.hidden_neuron(I_t, beta, alpha, v_th)
+
+         # Project the output from V_post (membrane potential), matching forward_parallel
+         V_post = self.hidden_neuron.v  # membrane potential after fire + reset
+         I_out = self.W_out(V_post)
+         I_total = I_out * gate + I_skip
+
+         return I_total  # continuous values
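Step 1 of `_initialize_parameters` places the β biases at the logits of the target decay rates, so that the sigmoid in `_fused_modulation` recovers them when the input projection contributes nothing. A minimal pure-Python check (`beta_logit_bias` and `sigmoid` are hypothetical helper names, not repo code):

```python
import math

def beta_logit_bias(beta):
    """Bias b_β such that sigmoid(b_β) equals the target decay rate β."""
    return math.log(beta / (1.0 - beta))

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# With zero input, β(t) = sigmoid(0 + b_β) should reproduce the target exactly.
targets = [0.80, 0.90, 0.99]
recovered = [sigmoid(beta_logit_bias(b)) for b in targets]
```

This is why the code calls the spacing "logit-spaced": linearly spaced β targets in [0.80, 0.99] map to nonlinearly spaced biases, and the N(0, 0.1) noise only perturbs them around those means.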
atomic_ops/snn_decoder_layer.py ADDED
@@ -0,0 +1,327 @@
+ """
+ SNNDecoderLayer: a single SNN decoder layer (Pre-LN continuous residual stream + dynamic-K frame aggregation).
+
+     RMSNorm → PLIF → SNNBlock → dynamic-K aggregation → out_proj → residual
+     RMSNorm → PLIF → SNNFFN  → dynamic-K aggregation → out_proj → residual
+
+ Dynamic K:
+     - K is a maximum step count (K_max), not a fixed one. Effective steps per token ∈ [1, K_max].
+     - From each token's K frames of SNN output, learn an adaptive halting probability p_halt.
+     - PonderNet geometric weighting: λ_k = p_k · ∏_{j<k}(1-p_j), normalized, then weighted aggregation.
+     - Tokens differ in effective step count: easy tokens halt early (small E[K]), hard tokens use the full budget.
+     - ponder_cost regularization: encourages finishing easy tokens in fewer steps.
+
+ Derivation:
+     halting prob:  p_k = σ(halt_proj(frame_k)) ∈ (0,1)
+     survival prob: S_k = ∏_{j=1}^{k-1} (1 - p_j) — still running at step k
+     weight:        λ_k = p_k · S_k — probability of halting exactly at step k
+     normalization: λ̂_k = λ_k / Σ_k λ_k — weights sum to 1
+     aggregation:   output = Σ_k λ̂_k · frame_k
+     cost:          E[K] = Σ_k k · λ̂_k — expected step count
+
+ K_max design principle:
+     A larger K_max gives the model more capacity for hard tokens (more steps
+     available), but compute and memory grow linearly. K_max=32 lets a token use
+     1~32 steps. PonderNet's ponder_cost regularization keeps easy tokens from
+     wasting steps.
+
+ K-frame inter-layer aggregation:
+     - The SNN sublayer outputs K frames of continuous values (projected V_post); PonderNet weights them into 1 per token.
+     - The aggregate goes through out_proj and is broadcast back to K frames for the residual.
+     - This lets β's temporal dynamics propagate effective gradients through the K-frame aggregation.
+
+ Mirrors Qwen3DecoderLayer (fully equivalent in Pre-LN form):
+     Qwen3: RMSNorm → Attention → residual → RMSNorm → MLP → residual
+     SNN:   RMSNorm → PLIF → SNNBlock → dynamic-K agg → out_proj → residual
+            → RMSNorm → PLIF → SNNFFN → dynamic-K agg → out_proj → residual
+ """
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from spikingjelly.activation_based import base, surrogate
+
+ from .plif_node import PLIFNode
+ from .rms_norm import RMSNorm
+ from .snn_block import SNNBlock
+ from .snn_ffn import SNNFFN
+ from .parallel_scan import plif_rowparam_forward
+
+
+ # ====== Fused halt weight computation (torch.compile) ======
+ # 7-8 separate element-wise kernels → one fused kernel:
+ # sigmoid + clamp + log1p + cumsum + exp + normalize.
+ # The first call triggers JIT compilation (~seconds); later calls hit the cache.
+
+ @torch.compile(backend='inductor', fullgraph=True)
+ def _fused_geometric_halt(halt_logits):
+     """Fused computation of the PonderNet geometric halting weights.
+
+     Input:  halt_logits (seq_len, K, batch) — raw output of halt_proj
+     Output: halt_weights (seq_len, K, batch) — normalized geometric weights, sum=1
+
+     Math: p_k = σ(logit_k), S_k = ∏_{j<k}(1-p_j), λ_k = p_k·S_k, λ̂_k = λ_k/Σλ
+     """
+     p_halt = torch.sigmoid(halt_logits).clamp(min=1e-7, max=1.0 - 1e-7)
+     log_1_minus_p = torch.log1p(-p_halt)  # (seq_len, K, batch)
+     # Exclusive cumsum: log_survive[:, k, :] = Σ_{j<k} log(1-p_j)
+     # Avoids torch.cat: fill [:, 1:] with the cumsum of [:, :-1]
+     log_survive = torch.zeros_like(log_1_minus_p)
+     log_survive[:, 1:, :] = torch.cumsum(log_1_minus_p[:, :-1, :], dim=1)
+     survive = torch.exp(log_survive)  # (seq_len, K, batch)
+     halt_weights = p_halt * survive  # λ_k = p_k · S_k
+     halt_weights = halt_weights / (halt_weights.sum(dim=1, keepdim=True) + 1e-8)
+     return halt_weights
+
+
+ class SNNDecoderLayer(base.MemoryModule):
+     """
+     A single SNN decoder layer (continuous residual stream + K-frame aggregation).
+
+     Continuous values h (TK, batch, D) flow between layers; PLIF neurons convert
+     them to spikes, the SNN sublayers process them, the K frames are aggregated
+     into 1 per token, projected through out_proj, and broadcast back to K frames
+     for the residual connection.
+
+     K-frame aggregation gives β's temporal dynamics (which govern the membrane
+     potential over the K steps) a differentiable token-level effect, fixing the
+     problem of β gradients being pure noise.
+
+     Args:
+         D: visible dimension
+         N: state expansion factor
+         D_ff: FFN intermediate dimension
+         v_th_min: SNNBlock dynamic-threshold lower bound
+         ffn_v_threshold: SNNFFN gate/up neuron threshold
+         K: maximum SNN time steps per token
+         num_layers: total layer count (for residual output scaling + SNNFFN down_proj scaling)
+         layer_idx: index of this layer
+     """
+
+     def __init__(
+         self,
+         D: int,
+         N: int,
+         D_ff: int,
+         v_th_min: float,
+         ffn_v_threshold: float,
+         K: int = 16,
+         num_layers: int = 1,
+         layer_idx: int = 0,
+     ):
+         super().__init__()
+         self.D = D
+         self.K = K
+
+         self.snn_block = SNNBlock(
+             D=D, N=N, v_th_min=v_th_min,
+         )
+         self.snn_ffn = SNNFFN(
+             D=D, D_ff=D_ff,
+             output_v_threshold=ffn_v_threshold,
+             num_layers=num_layers,
+             layer_idx=layer_idx,
+         )
+
+         # Pre-LN branch normalization: h → RMSNorm → PLIFNode
+         self.block_norm = RMSNorm(D)
+         self.ffn_norm = RMSNorm(D)
+
+         # Input neurons: RMSNorm(h) → V_post membrane activation (D-dim learnable β and V_th)
+         self.input_neuron1 = PLIFNode(
+             dim=D,
+             init_tau=2.0,
+             v_threshold=0.5,
+             surrogate_function=surrogate.Sigmoid(alpha=4.0),
+         )
+         self.input_neuron2 = PLIFNode(
+             dim=D,
+             init_tau=2.0,
+             v_threshold=0.5,
+             surrogate_function=surrogate.Sigmoid(alpha=4.0),
+         )
+
+         # Output projections (synapses): spike (D) → continuous space (D)
+         self.block_out_proj = nn.Linear(D, D, bias=False)
+         self.ffn_out_proj = nn.Linear(D, D, bias=False)
+
+         # ====== Dynamic K: halting projections (synapse: SNN output → halting probability) ======
+         # halt_proj: D → 1, one halting logit per step per token.
+         # PonderNet geometric weighting replaces uniform mean aggregation.
+         self.block_halt = nn.Linear(D, 1, bias=True)
+         self.ffn_halt = nn.Linear(D, 1, bias=True)
+
+         # Residual output scaling init (GPT-2 style: σ = 0.02 / √(2·num_layers))
+         std = 0.02 / math.sqrt(2 * num_layers)
+         nn.init.normal_(self.block_out_proj.weight, std=std)
+         nn.init.normal_(self.ffn_out_proj.weight, std=std)
+
+         # Halt init: small weights + negative bias → p_halt ≈ 0.03 → near-uniform aggregation
+         # σ(-3.5) ≈ 0.029; after geometric normalization λ_1/λ_K ≈ 1.5, close to uniform
+         for halt in [self.block_halt, self.ffn_halt]:
+             nn.init.xavier_uniform_(halt.weight)
+             halt.weight.data.mul_(0.01)
+             nn.init.constant_(halt.bias, -3.5)
+
+     def _input_neuron_parallel(self, input_neuron, x):
+         """
+         Parallel-scan forward pass for an input PLIF neuron.
+
+         Full PLIF dynamics: V[t] = β·V[t-1] + (1-β)·x[t], spike = Θ(V-V_th), soft reset.
+         The membrane leak (1-β)·V_post is used as the activation value — the amount
+         each step loses to exponential decay. Compared with passing V_post directly,
+         the leak naturally emphasizes fast-responding neurons (large 1-β) and
+         suppresses slow-memory neurons (small 1-β), an implicit timescale weighting.
+
+         Args:
+             input_neuron: a PLIFNode instance (D-dim learnable β and V_th)
+             x: (TK, batch, D) — continuous-valued input
+
+         Returns:
+             leak: (TK, batch, D) — membrane leak (1-β)·V_post
+         """
+         TK, batch, D = x.shape
+
+         beta = input_neuron.beta  # (D,)
+         u = (1.0 - beta) * x  # (D,) broadcast → (TK, batch, D)
+
+         v_init = input_neuron.v
+         if isinstance(v_init, float):
+             v_init = torch.zeros(batch, D, device=x.device, dtype=x.dtype)
+
+         beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
+         v_th_row = input_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()
+
+         spike, V_post = plif_rowparam_forward(
+             beta_row, u, v_th_row, v_init,
+             surrogate_function=input_neuron.surrogate_function,
+         )
+
+         input_neuron.v = V_post[-1].detach()
+         return (1.0 - beta) * V_post  # membrane leak
+
+     def _adaptive_aggregate(self, frames, halt_proj):
+         """
+         PonderNet-style adaptive K-frame aggregation (the dynamic-K core; fused via torch.compile).
+
+         Computes a halting probability p_k at each step and aggregates with
+         geometric weights, so different tokens get different effective step counts.
+
+         Optimization: _fused_geometric_halt fuses sigmoid+log1p+cumsum+exp+normalize
+         into a single inductor kernel (same pattern as snn_block._fused_modulation).
+
+         Math:
+             p_k = σ(halt_proj(frame_k)) — halting probability
+             S_k = ∏_{j<k} (1-p_j) — survival probability
+             λ_k = p_k · S_k — geometric weight
+             λ̂_k = λ_k / Σ λ_k — normalization
+             output = Σ λ̂_k · frame_k — weighted aggregation
+             E[K] = Σ k · λ̂_k — expected step count (ponder cost)
+
+         Args:
+             frames: (seq_len, K, batch, D) — the SNN sublayer's K output frames
+             halt_proj: nn.Linear(D, 1) — halting projection (synapse)
+
+         Returns:
+             aggregated: (seq_len, batch, D) — weighted aggregate
+             ponder_cost: scalar — mean expected step count (for regularization)
+             expected_k: (seq_len, batch) — detached per-token E[K] (diagnostics)
+         """
+         seq_len, K, batch, D = frames.shape
+
+         # ====== 1. halt_proj matmul (cuBLAS) + fused geometric weights (inductor) ======
+         halt_logits = halt_proj(frames).squeeze(-1)  # (seq_len, K, batch)
+         halt_weights = _fused_geometric_halt(halt_logits)  # (seq_len, K, batch), normalized
+
+         # ====== 2. Weighted aggregation ======
+         # (seq_len, K, batch, 1) × (seq_len, K, batch, D) → sum → (seq_len, batch, D)
+         aggregated = (frames * halt_weights.unsqueeze(-1)).sum(dim=1)
+
+         # ====== 3. Ponder cost: E[K] per token ======
+         steps = torch.arange(1, K + 1, device=frames.device, dtype=frames.dtype)
+         expected_k = (halt_weights * steps[None, :, None]).sum(dim=1)  # (seq_len, batch)
+         ponder_cost = expected_k.mean()  # scalar
+
+         return aggregated, ponder_cost, expected_k.detach()
+
+     def forward_parallel(self, h):
+         """
+         Parallel forward pass: continuous residual stream + dynamic-K frame aggregation.
+
+         The SNN sublayers run over the TK dimension (K steps of temporal dynamics);
+         their outputs are adaptively aggregated over the K frames with PonderNet
+         (effective step count varies by token), projected through out_proj, and
+         broadcast back to TK for the residual.
+
+         Args:
+             h: (TK, batch, D) — continuous-valued input
+
+         Returns:
+             h: (TK, batch, D) — continuous-valued output
+             ponder_cost: scalar — mean expected step count over both sublayers (for regularization)
+         """
+         TK, batch, D = h.shape
+         K = self.K
+         seq_len = TK // K
+
+         # Sublayer 1: SNNBlock — RMSNorm → PLIFNode(V_post) → SNNBlock → dynamic-K agg → out_proj → residual
+         v_in = self._input_neuron_parallel(self.input_neuron1, self.block_norm(h))
+         cont_block = self.snn_block.forward_parallel(v_in)  # (TK, batch, D), continuous
+
+         # Dynamic K-frame aggregation (PonderNet): (TK, batch, D) → (seq_len, K, batch, D) → weighted → (seq_len, batch, D)
+         frames_block = cont_block.view(seq_len, K, batch, D)
+         combined_block, pc_block, ek_block = self._adaptive_aggregate(frames_block, self.block_halt)
+         res_block = self.block_out_proj(combined_block)  # (seq_len, batch, D)
+         res_block = res_block - res_block.mean(dim=-1, keepdim=True)  # center the residual
+
+         # Broadcast back to TK: each token's residual is copied K times
+         h = h + res_block.repeat_interleave(K, dim=0)
+
+         # Sublayer 2: SNNFFN — RMSNorm → PLIFNode(V_post) → SNNFFN → dynamic-K agg → out_proj → residual
+         v_in2 = self._input_neuron_parallel(self.input_neuron2, self.ffn_norm(h))
+         cont_ffn = self.snn_ffn.forward_parallel(v_in2)  # (TK, batch, D), continuous
+
+         frames_ffn = cont_ffn.view(seq_len, K, batch, D)
+         combined_ffn, pc_ffn, ek_ffn = self._adaptive_aggregate(frames_ffn, self.ffn_halt)
+         res_ffn = self.ffn_out_proj(combined_ffn)
+         res_ffn = res_ffn - res_ffn.mean(dim=-1, keepdim=True)
+
+         h = h + res_ffn.repeat_interleave(K, dim=0)
+
+         ponder_cost = (pc_block + pc_ffn) / 2.0  # average over the two sublayers
+
+         # Store the per-token E[K] range (diagnostics only; not in the graph)
+         # ek_block/ek_ffn: (seq_len, batch), detached
+         with torch.no_grad():
+             all_ek = torch.cat([ek_block.flatten(), ek_ffn.flatten()])
+             self._ek_min = all_ek.min().item()
+             self._ek_max = all_ek.max().item()
+
+         return h, ponder_cost
+
+     def single_step_forward(self, h):
+         """
+         Single-step forward pass: continuous residual stream.
+
+         Note: single-step mode cannot do dynamic-K aggregation (each step is
+         processed independently). Both training and inference use forward_parallel
+         (with dynamic-K aggregation); this method is for debugging only.
+
+         Args:
+             h: (batch, D) — continuous-valued input
+
+         Returns:
+             h: (batch, D) — continuous-valued output
+             ponder_cost: scalar — 0.0 (no ponder cost in single-step mode)
+         """
+         # Sublayer 1: SNNBlock — RMSNorm → PLIFNode(leak) → SNNBlock → out_proj → residual
+         _ = self.input_neuron1(self.block_norm(h))  # run PLIF dynamics, update .v
+         v_in = (1.0 - self.input_neuron1.beta) * self.input_neuron1.v  # membrane leak
+         cont_block = self.snn_block.single_step_forward(v_in)
+         res_block = self.block_out_proj(cont_block)
+         h = h + res_block - res_block.mean(dim=-1, keepdim=True)
+
+         # Sublayer 2: SNNFFN — RMSNorm → PLIFNode(leak) → SNNFFN → out_proj → residual
+         _ = self.input_neuron2(self.ffn_norm(h))
+         v_in2 = (1.0 - self.input_neuron2.beta) * self.input_neuron2.v  # membrane leak
+         cont_ffn = self.snn_ffn.single_step_forward(v_in2)
+         res_ffn = self.ffn_out_proj(cont_ffn)
+         h = h + res_ffn - res_ffn.mean(dim=-1, keepdim=True)
+
+         return h, torch.tensor(0.0, device=h.device)
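The geometric halting weights computed by `_fused_geometric_halt` can be replicated for a single token in plain Python. `geometric_halt_weights` is a hypothetical sketch of the same math, not the fused kernel:

```python
def geometric_halt_weights(p_halt):
    """Normalized PonderNet weights λ_k = p_k · ∏_{j<k}(1 - p_j) for one token."""
    weights, survive = [], 1.0
    for p in p_halt:
        weights.append(p * survive)  # λ_k = p_k · S_k
        survive *= 1.0 - p           # S_{k+1} = S_k · (1 - p_k)
    total = sum(weights)
    return [w / total for w in weights]  # λ̂_k, summing to 1

# At the init bias, σ(-3.5) ≈ 0.029, so over K=16 steps the weights are
# nearly uniform (λ_1/λ_16 ≈ 1.5), matching the comment in __init__.
w = geometric_halt_weights([0.029] * 16)
expected_k = sum((k + 1) * wk for k, wk in enumerate(w))  # E[K] = Σ k·λ̂_k
```

As training sharpens the halting logits, early weights grow for easy tokens and `expected_k` falls, which is exactly what the ponder_cost regularizer rewards.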
atomic_ops/snn_ffn.py ADDED
@@ -0,0 +1,185 @@
+ """
+ SNNFFN: the SNN equivalent of a feed-forward network.
+
+ Mirrors Qwen3MLP's SwiGLU structure:
+     Qwen3 MLP: down_proj( SiLU(gate_proj(x)) * up_proj(x) )
+     SNN FFN:   down_proj( gate_V_post * up_V_post ) + skip
+
+ Membrane-potential gating (the SiLU-gating analogue):
+     The gate/up neurons run full PLIF dynamics (integrate + threshold + reset)
+     and output membrane potentials V_post for continuous multiplicative gating,
+     replacing a binary AND gate.
+
+ Signal flow:
+     x → gate_proj → gate_neuron → V_post_gate
+     x → up_proj → up_neuron → V_post_up
+     V_post_gate × V_post_up → gated
+     down_proj(gated) + skip_proj(x) → continuous output
+ """
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from spikingjelly.activation_based import base, layer, surrogate
+
+ from .plif_node import PLIFNode
+ from .parallel_scan import plif_rowparam_forward
+
+
+ class SNNFFN(base.MemoryModule):
+     """
+     The SNN equivalent of a feed-forward network.
+
+     Args:
+         D: visible dimension (input/output spike dimension)
+         D_ff: intermediate dimension (matches Qwen3 intermediate_size)
+         output_v_threshold: output neuron threshold
+         num_layers: total layer count, used to scale down_proj
+         layer_idx: index of this layer
+         surrogate_function: surrogate gradient function
+     """
+
+     def __init__(
+         self,
+         D: int,
+         D_ff: int,
+         output_v_threshold: float = 0.3,
+         num_layers: int = 1,
+         layer_idx: int = 0,
+         surrogate_function=surrogate.Sigmoid(alpha=4.0),
+     ):
+         super().__init__()
+         self.D = D
+         self.D_ff = D_ff
+
+         # ====== Three projection paths (matching SwiGLU: gate_proj, up_proj, down_proj) ======
+         self.gate_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
+         self.up_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
+         self.down_proj = layer.Linear(D_ff, D, bias=False, step_mode='s')
+
+         # ====== Residual path ======
+         self.skip_proj = layer.Linear(D, D, bias=False, step_mode='s')
+
+         # ====== Neurons (learnable β and V_th, D- or D_ff-dimensional) ======
+         # gate_neuron: gating spikes
+         self.gate_neuron = PLIFNode(
+             dim=D_ff,
+             init_tau=2.0,
+             v_threshold=output_v_threshold,
+             surrogate_function=surrogate_function,
+         )
+         # up_neuron: value spikes
+         self.up_neuron = PLIFNode(
+             dim=D_ff,
+             init_tau=2.0,
+             v_threshold=output_v_threshold,
+             surrogate_function=surrogate_function,
+         )
+         # ====== Parameter initialization ======
+         self._initialize_parameters(num_layers)
+
+     def _initialize_parameters(self, num_layers: int):
+         """Initialize projection weights.
+
+         - gate_proj, up_proj, skip_proj: Kaiming uniform
+         - down_proj: Kaiming uniform × 1/√(num_layers), against gradient blow-up in deep stacks
+         """
+         for lin in [self.gate_proj, self.up_proj, self.skip_proj]:
+             nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
+
+         nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
+         self.down_proj.weight.data.mul_(1.0 / math.sqrt(num_layers))
+
+     def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
+         """
+         Parallel forward pass: process the whole sequence with a parallel scan.
+
+         Optimizations:
+         - gate_proj + up_proj merged into a single matmul (2 launches → 1)
+         - gate + up PLIF scan: row-param kernel (no expand+contiguous of beta/v_th needed)
+         - u_merged: vector scaling instead of cat (1 broadcast multiply replaces 2 scales + 1 cat)
+
+         Args:
+             spike_in_seq: (TK, batch, D) — input spikes for all T×K frames
+
+         Returns:
+             continuous_out: (TK, batch, D) — continuous output for all T×K frames
+         """
+         TK, batch, D = spike_in_seq.shape
+         D_ff = self.D_ff
+         flat = spike_in_seq.reshape(TK * batch, D)
+
+         # ====== Phase 1: batched projections (gate+up merged into one matmul) ======
+         W_gate_up = torch.cat([self.gate_proj.weight, self.up_proj.weight], dim=0)
+         I_gate_up = F.linear(flat, W_gate_up).reshape(TK, batch, 2 * D_ff)
+         I_skip = F.linear(flat, self.skip_proj.weight).reshape(TK, batch, D)
+
+         # ====== Phase 2: merged gate+up PLIF scan (row-param kernel) ======
+         beta_gate = self.gate_neuron.beta  # (D_ff,)
+         beta_up = self.up_neuron.beta  # (D_ff,)
+         surr = self.gate_neuron.surrogate_function
+
+         # u_merged: vector scaling (D_ff-dim β concatenated directly, no expand needed)
+         scale_row = torch.cat([1.0 - beta_gate, 1.0 - beta_up])  # (2*D_ff,)
+         u_merged = I_gate_up * scale_row  # (TK, batch, 2*D_ff), broadcast
+
+         # beta_row / v_th_row: (batch, 2*D_ff) — learnable D_ff-dim parameters
+         beta_row = torch.cat([beta_gate, beta_up])  # (2*D_ff,)
+         beta_row = beta_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()
+
+         v_th_row = torch.cat([self.gate_neuron.v_th, self.up_neuron.v_th])  # (2*D_ff,)
+         v_th_row = v_th_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()
+
+         # v_init_merged: (batch, 2*D_ff)
+         v_init_gate = self.gate_neuron.v
+         if isinstance(v_init_gate, float):
+             v_init_gate = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
+         v_init_up = self.up_neuron.v
+         if isinstance(v_init_up, float):
+             v_init_up = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
+         v_init_merged = torch.cat([v_init_gate, v_init_up], dim=-1)
+
+         # Row-param PLIF scan: beta/v_th read from registers, no memory-bandwidth cost
+         spike_merged, V_post_merged = plif_rowparam_forward(
+             beta_row, u_merged, v_th_row, v_init_merged,
+             surrogate_function=surr,
+         )
+
+         # Membrane leak as the activation value: leak = (1-β) · V_post
+         gate_leak = V_post_merged[:, :, :D_ff] * (1.0 - beta_gate)  # (TK, batch, D_ff)
+         up_leak = V_post_merged[:, :, D_ff:] * (1.0 - beta_up)  # (TK, batch, D_ff)
+         self.gate_neuron.v = V_post_merged[-1, :, :D_ff].detach()
+         self.up_neuron.v = V_post_merged[-1, :, D_ff:].detach()
+
+         # ====== Phase 3: continuous gating (leak × leak, the SwiGLU analogue) + down-projection ======
+         gated = gate_leak * up_leak  # (TK, batch, D_ff)
+         gated_flat = gated.reshape(TK * batch, D_ff)
+         I_out = F.linear(gated_flat, self.down_proj.weight).reshape(TK, batch, D) + I_skip
+
+         # output_neuron removed: continuous values are handled by layer-level K-frame aggregation
+         return I_out  # (TK, batch, D), continuous values
+
+     def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
+         """
+         Single-step forward pass.
+
+         Args:
+             spike_in: binary spike input, shape (batch, D), values in {0, 1}
+
+         Returns:
+             continuous_out: continuous output, shape (batch, D)
+         """
+         # Gating path — membrane-leak activation
+         _ = self.gate_neuron(self.gate_proj(spike_in))
+         gate_leak = (1.0 - self.gate_neuron.beta) * self.gate_neuron.v  # leak
+
+         # Value path — membrane-leak activation
+         _ = self.up_neuron(self.up_proj(spike_in))
+         up_leak = (1.0 - self.up_neuron.beta) * self.up_neuron.v  # leak
+
+         # Continuous gating (the SwiGLU analogue)
+         gated = gate_leak * up_leak
+
+         # Down-projection + residual
+         I_out = self.down_proj(gated) + self.skip_proj(spike_in)  # R^D
+         return I_out  # continuous values
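The gate+up merge in `forward_parallel` relies on a simple identity: row-concatenating two weight matrices and doing one matrix-vector product gives the same result as doing the two products separately and concatenating the outputs. A tiny pure-Python check (`matvec` is a hypothetical helper, standing in for `F.linear`):

```python
def matvec(W, x):
    """Plain matrix-vector product; W is a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

W_gate = [[1.0, 2.0], [0.5, -1.0]]  # toy (D_ff=2, D=2) gate weights
W_up = [[3.0, -1.0], [0.0, 4.0]]    # toy (D_ff=2, D=2) up weights
x = [0.5, 0.25]

merged = matvec(W_gate + W_up, x)                # one product over row-concatenated weights
separate = matvec(W_gate, x) + matvec(W_up, x)   # two products, outputs concatenated
```

The merged form halves the kernel-launch count; the first `D_ff` output rows are the gate pre-activations and the rest are the up pre-activations, which is why the scan result is later split at `D_ff`.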
config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "model_type": "neuronspark",
+   "architectures": [
+     "NeuronSparkForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_neuronspark.NeuronSparkConfig",
+     "AutoModelForCausalLM": "modeling_neuronspark.NeuronSparkForCausalLM"
+   },
+   "vocab_size": 6144,
+   "D": 896,
+   "N": 8,
+   "K": 16,
+   "num_layers": 20,
+   "D_ff": 2688,
+   "v_th_min": 0.1,
+   "torch_dtype": "float32",
+   "transformers_version": "4.52.0",
+   "_training_step": 6500,
+   "_tokens_seen": 10459816
+ }
configuration.json ADDED
@@ -0,0 +1 @@
+ {"framework": "pytorch", "task": "other"}
configuration_neuronspark.py ADDED
@@ -0,0 +1,38 @@
+ """NeuronSpark model configuration."""
+
+ from transformers import PretrainedConfig
+
+
+ class NeuronSparkConfig(PretrainedConfig):
+     """
+     Configuration for the SNN hidden-state-space language model.
+
+     Args:
+         vocab_size: vocabulary size
+         D: hidden dimension
+         N: state expansion factor (hidden neurons per channel)
+         K: maximum SNN time steps per token (PonderNet decides the effective count)
+         num_layers: number of SNN decoder layers
+         D_ff: FFN intermediate dimension
+         v_th_min: dynamic-threshold lower bound
+     """
+     model_type = "neuronspark"
+
+     def __init__(
+         self,
+         vocab_size=6144,
+         D=896,
+         N=8,
+         K=16,
+         num_layers=20,
+         D_ff=2688,
+         v_th_min=0.1,
+         **kwargs,
+     ):
+         self.D = D
+         self.N = N
+         self.K = K
+         self.num_layers = num_layers
+         self.D_ff = D_ff
+         self.v_th_min = v_th_min
+         super().__init__(vocab_size=vocab_size, **kwargs)
model.py ADDED
@@ -0,0 +1,471 @@
1
+ """
2
+ SNNLanguageModel: SNN 隐状态空间语言模型(全膜电位 + 动态 K)
3
+
4
+ 架构(三段式):
5
+ model.encode(token_ids) → h_seq # 输入: embed → repeat K 次(可微分)
6
+ model.snn_forward(h_seq) → h_out, pc # SNN 核心: 20 层,全膜电位 + 动态 K 聚合
7
+ model.decode(h_out, seq) → logits # 输出: output_neuron(V_post) → K帧mean → proj → logits
8
+
9
+ 核心设计:
10
+ 1. 膜电位泄漏量:PLIFNode 输出 (1-β)·V_post(泄漏量),自然强调快响应神经元
11
+ 2. 动态 K:PonderNet 自适应停止,不同 token 不同有效步数
12
+ - 每层每子层学习 halt_proj(D→1),从 SNN 输出逐步计算停止概率
13
+ - 几何分布权重加权聚合,替代 uniform mean
14
+ - ponder_cost 正则化鼓励早停
15
+
16
+ 数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
17
+ """
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import Optional
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from spikingjelly.activation_based import functional, surrogate
27
+ from torch.utils.checkpoint import checkpoint
28
+
29
+ from atomic_ops import SNNDecoderLayer
30
+ from atomic_ops.plif_node import PLIFNode
31
+ from atomic_ops.rms_norm import RMSNorm
32
+ from atomic_ops.parallel_scan import plif_rowparam_forward
33
+ # fp16_encode/fp16_decode 已移除: 全膜电位架构不需要 spike 编解码
34
+ from atomic_ops.lateral_inhibition import LateralInhibition
35
+
36
+
37
+ @dataclass
38
+ class SNNModelOutput:
39
+ """模型输出容器,对齐教程 CausalLMOutputWithPast 接口。"""
40
+ last_loss: Optional[torch.Tensor] = None
41
+ logits: Optional[torch.Tensor] = None
42
+ ponder_cost: Optional[torch.Tensor] = None # 动态 K: 平均期望步数
43
+
44
+
45
+ class SNNLanguageModel(nn.Module):
46
+ """
47
+ 从零训练的 SNN 隐状态空间语言模型(parallel scan)。
48
+
49
+ Args:
50
+ vocab_size: 词表大小(默认 6144,自训练 BPE)
51
+ D: 可见维度
52
+ N: 状态扩展因子
53
+ K: 每 token 最大 SNN 时间步数(K_max)。PonderNet 动态决定有效步数 ∈ [1, K]。
54
+ K 越大 → 复杂 token 可用更多步数,但计算量和显存线性增长。
55
+ num_layers: SNN 解码层数
56
+ D_ff: FFN 中间层维度
57
+ v_th_min: 动态阈值下限
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ vocab_size: int = 6144,
63
+ D: int = 1024,
64
+ N: int = 8,
65
+ K: int = 32,
66
+ num_layers: int = 20,
67
+ D_ff: int = 3072,
68
+ v_th_min: float = 0.1,
69
+ ):
70
+ super().__init__()
71
+ self.vocab_size = vocab_size
72
+ self.D = D
73
+ self.N = N
74
+ self.K = K
75
+ self.num_layers = num_layers
76
+ self.D_ff = D_ff
77
+
78
+ # ====== Embedding + Norm(全部可训练)======
79
+ self.embed_tokens = nn.Embedding(vocab_size, D)
80
+ self.norm = LateralInhibition(D)
81
+
82
+ # ====== 解码投影 ======
83
+ self.decode_proj = nn.Linear(D, D)
84
+
85
+ # ====== 输出 RMSNorm + 输出神经元 ======
86
+ self.output_norm = RMSNorm(D)
87
+ self.output_neuron = PLIFNode(
88
+ dim=D,
89
+ init_tau=2.0,
90
+ v_threshold=0.3,
91
+ surrogate_function=surrogate.Sigmoid(alpha=4.0),
92
+ )
93
+
94
+ # ====== SNN Decoder Layers ======
95
+ self.layers = nn.ModuleList([
96
+ SNNDecoderLayer(
97
+ D=D, N=N, D_ff=D_ff, v_th_min=v_th_min,
98
+ ffn_v_threshold=0.15,
99
+ K=K,
100
+ num_layers=num_layers,
101
+ layer_idx=i,
102
+ )
103
+ for i in range(num_layers)
104
+ ])
105
+
106
+ self._init_weights()
107
+
108
+ def _init_weights(self):
109
+ """初始化所有可训练权重(从零训练)。"""
110
+ nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
111
+ nn.init.xavier_uniform_(self.decode_proj.weight)
112
+ nn.init.zeros_(self.decode_proj.bias)
113
+
114
+ def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
115
+ """输入边界:token_ids → 连续值序列。
116
+
117
+ Embedding lookup,每 token 重复 K 次作为 SNN 时间步输入。
118
+ 梯度可通过 embedding 直接反传。
119
+
120
+ Returns: (seq_len*K, batch, D), 连续值
121
+ """
122
+ emb = self.embed_tokens(token_ids) # (batch, seq_len, D)
123
+ batch, seq_len, D = emb.shape
124
+ # Repeat each token K times: (batch, seq_len, D) → (batch, seq_len*K, D) → (TK, batch, D)
125
+ emb_k = emb.unsqueeze(2).expand(-1, -1, self.K, -1).reshape(batch, seq_len * self.K, D)
126
+ return emb_k.permute(1, 0, 2).contiguous() # (TK, batch, D)
127
+
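The repeat-K expansion in `encode` can be illustrated with plain lists (a shape sketch only; the real code performs the same duplication with `expand`/`reshape` on tensors):

```python
def repeat_k(embeddings, K):
    """Duplicate each per-token embedding K times along the time axis:
    (seq_len, D) -> (seq_len*K, D). Shape sketch of encode()."""
    frames = []
    for emb in embeddings:        # one embedding vector per token
        frames.extend([emb] * K)  # K identical SNN time-step frames
    return frames

frames = repeat_k([[0.1, 0.2], [0.3, 0.4]], K=3)  # seq_len=2, D=2
assert len(frames) == 2 * 3                       # seq_len * K frames
assert frames[3] == [0.3, 0.4]                    # frames 3..5 come from token 1
```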
128
+ def snn_forward(self, spike_seq: torch.Tensor):
129
+ """SNN 核心:spike_seq → (h_out, ponder_cost)。
130
+
131
+ 纯 SNN 层计算,带梯度检查点。
132
+ 每层返回 (h, ponder_cost),ponder_cost 作为 checkpoint 输出保留梯度图。
133
+
134
+ Returns:
135
+ h: (seq_len*K, batch, D), 连续值
136
+ total_ponder_cost: scalar, 所有层平均期望步数
137
+ """
138
+ h = spike_seq
139
+ ponder_costs = []
140
+
141
+ def _layer_forward(layer_mod, x):
142
+ functional.reset_net(layer_mod)
143
+ return layer_mod.forward_parallel(x) # returns (h, ponder_cost)
144
+
145
+ for layer_module in self.layers:
146
+ h, pc = checkpoint(
147
+ _layer_forward, layer_module, h,
148
+ use_reentrant=False,
149
+ )
150
+ ponder_costs.append(pc)
151
+
152
+ total_ponder_cost = sum(ponder_costs) / len(ponder_costs)
153
+ return h, total_ponder_cost
154
+
155
+ def _output_neuron_parallel(self, h: torch.Tensor) -> torch.Tensor:
156
+ """输出 PLIF 神经元的 parallel scan 前向:连续 h → 膜电位泄漏量。
157
+
158
+ Args:
159
+ h: (TK, batch, D) 连续值(SNN 最后一层输出)
160
+
161
+ Returns:
162
+ leak: (TK, batch, D) 膜电位泄漏量 (1-β)·V_post
163
+ """
164
+ TK, batch, D = h.shape
165
+
166
+ beta = self.output_neuron.beta # (D,)
167
+ u = (1.0 - beta) * h # PLIF: u = (1-β) · x
168
+
169
+ v_init = self.output_neuron.v
170
+ if isinstance(v_init, float):
171
+ v_init = torch.zeros(batch, D, device=h.device, dtype=h.dtype)
172
+
173
+ beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
174
+ v_th_row = self.output_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()
175
+
176
+ spike, V_post = plif_rowparam_forward(
177
+ beta_row, u, v_th_row, v_init,
178
+ surrogate_function=self.output_neuron.surrogate_function,
179
+ )
180
+
181
+ self.output_neuron.v = V_post[-1].detach()
182
+ return (1.0 - beta) * V_post # membrane-potential leak
183
+
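The parallel scan above computes the same recurrence a step-by-step PLIF update would. A minimal sequential reference in plain Python (an illustrative sketch: it assumes the recurrence V_t = β·V_{t-1} + (1-β)·x_t implied by the `u = (1-β)·x` input term, and ignores any post-spike reset and surrogate-gradient machinery inside `plif_rowparam_forward`):

```python
def plif_scan(x_seq, beta, v_th, v0=0.0):
    """Sequential reference for the PLIF recurrence V_t = beta*V_{t-1} + (1-beta)*x_t.
    No post-spike reset is modeled; the decoder only reads the leak (1-beta)*V_post,
    so the spike train here is returned for illustration only."""
    v, spikes, vs = v0, [], []
    for x in x_seq:
        v = beta * v + (1.0 - beta) * x   # leaky integration of the input
        spikes.append(1.0 if v >= v_th else 0.0)
        vs.append(v)
    return spikes, vs

spikes, vs = plif_scan([1.0, 1.0, 1.0], beta=0.5, v_th=0.6)
assert vs == [0.5, 0.75, 0.875]   # membrane potential converges toward the input
assert spikes == [0.0, 1.0, 1.0]  # fires once V crosses v_th
```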
184
+ def decode(self, h_out: torch.Tensor, seq_len: int) -> torch.Tensor:
185
+ """输出边界:连续 h → 输出神经元(V_post) → K 帧聚合 → logits。
186
+
187
+ 梯度流: loss → logits → norm → decode_proj → K帧mean
188
+ → V_post(output_neuron) → h_out → SNN layers
189
+
190
+ Returns: (batch, seq_len, vocab_size)
191
+ """
192
+ h_out = self.output_norm(h_out) # RMSNorm: controls the scale
193
+ v_out = self._output_neuron_parallel(h_out) # (TK, batch, D), V_post membrane potential
194
+ # K-frame aggregation: (TK, batch, D) → (seq_len, K, batch, D) → mean → (seq_len, batch, D)
195
+ decoded = v_out.view(seq_len, self.K, -1, self.D).mean(dim=1)
196
+ decoded = decoded.permute(1, 0, 2) # (batch, seq_len, D)
197
+ h = self.decode_proj(decoded) # (batch, seq_len, D)
198
+ h = self.norm(h) # (batch, seq_len, D)
199
+ return F.linear(h, self.embed_tokens.weight) # (batch, seq_len, vocab)
200
+
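The K-frame aggregation in `decode` (group every K consecutive frames, mean-pool back to one vector per token) can be sketched with plain lists (hypothetical helper mirroring the `view(...).mean(dim=1)` step):

```python
def aggregate_k_frames(frames, seq_len, K):
    """Mean-pool every K consecutive frames back to one vector per token:
    (seq_len*K, D) -> (seq_len, D)."""
    assert len(frames) == seq_len * K
    out = []
    for t in range(seq_len):
        group = frames[t * K:(t + 1) * K]        # the K frames of token t
        out.append([sum(f[d] for f in group) / K for d in range(len(group[0]))])
    return out

frames = [[1.0], [3.0], [5.0], [7.0]]            # seq_len=2, K=2, D=1
assert aggregate_k_frames(frames, seq_len=2, K=2) == [[2.0], [6.0]]
```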
201
+ @torch.no_grad()
202
+ def generate(
203
+ self,
204
+ prompt_ids: torch.Tensor,
205
+ max_new_tokens: int,
206
+ temperature: float = 1.0,
207
+ top_k: int = 50,
208
+ eos_token_id: Optional[int] = None,
209
+ ) -> torch.Tensor:
210
+ """
211
+ Autoregressive generation (SNN neuron state is maintained continuously across tokens).
212
+
213
+ 1. Prefill: forward_parallel processes the whole prompt in parallel, establishing every neuron's V state
214
+ 2. Autoregressive: generate token by token, running each token's K frames through forward_parallel,
215
+ reusing the Triton parallel-scan kernel; neuron V states carry over between tokens
216
+
217
+ Args:
218
+ prompt_ids: (batch, prompt_len) token IDs
219
+ max_new_tokens: maximum number of new tokens to generate
220
+ temperature: sampling temperature (<=0 = greedy)
221
+ top_k: top-k sampling (None/0 = no restriction)
222
+ eos_token_id: stop generating when this token is produced
223
+
224
+ Returns:
225
+ (batch, prompt_len + generated_len) full sequence
226
+ """
227
+ batch, prompt_len = prompt_ids.shape
228
+
229
+ # Reset all neurons (initial condition V=0 for a new sequence)
230
+ for layer_module in self.layers:
231
+ functional.reset_net(layer_module)
232
+ functional.reset_net(self.output_neuron)
233
+
234
+ # ====== Prefill: process the whole prompt in parallel ======
235
+ h_seq = self.encode(prompt_ids) # (prompt_len*K, batch, D), continuous values
236
+ h = h_seq
237
+ for layer_module in self.layers:
238
+ h, _ = layer_module.forward_parallel(h) # ponder_cost ignored at inference
239
+ # At this point every neuron's .v in every layer holds the end-of-prompt state
240
+
241
+ logits = self.decode(h, prompt_len)
242
+
243
+ # Sample the first new token
244
+ next_token = self._sample(logits[:, -1, :], temperature, top_k)
245
+ generated = [next_token]
246
+
247
+ # ====== Autoregressive: token by token, forward_parallel over K frames ======
248
+ for _ in range(max_new_tokens - 1):
249
+ if eos_token_id is not None and (next_token == eos_token_id).all():
250
+ break
251
+
252
+ # Encode the single token → K continuous-valued frames (reusing encode)
253
+ frames = self.encode(next_token) # (K, batch, D)
254
+
255
+ # Run the K frames through the SNN without reset; neuron .v carries over across tokens
256
+ h = frames
257
+ for layer_module in self.layers:
258
+ h, _ = layer_module.forward_parallel(h)
259
+
260
+ logits = self.decode(h, 1)
261
+
262
+ next_token = self._sample(logits[:, -1, :], temperature, top_k)
263
+ generated.append(next_token)
264
+
265
+ return torch.cat([prompt_ids, torch.cat(generated, dim=1)], dim=1)
266
+
267
+ def _sample(self, logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
268
+ """Sample from logits (temperature + top-k).
269
+
270
+ Returns: (batch, 1)
271
+ """
272
+ if temperature <= 0:
273
+ return logits.argmax(dim=-1, keepdim=True)
274
+ logits = logits / temperature
275
+ if top_k is not None and top_k > 0:
276
+ top_k = min(top_k, logits.size(-1))
277
+ v, _ = torch.topk(logits, top_k)
278
+ logits[logits < v[:, [-1]]] = float('-inf')
279
+ probs = F.softmax(logits, dim=-1)
280
+ return torch.multinomial(probs, num_samples=1)
281
+
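The temperature + top-k logic of `_sample` can be mirrored for a single logits vector in plain Python (a sketch; `sample_top_k` is a hypothetical helper, not part of the model):

```python
import math
import random

def sample_top_k(logits, temperature=1.0, top_k=None):
    """Temperature scaling + top-k filtering + softmax sampling over a
    single logits list, mirroring _sample above."""
    if temperature <= 0:                        # greedy decoding
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [z / temperature for z in logits]
    if top_k:                                   # keep only the k largest logits
        kth = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [z if z >= kth else float("-inf") for z in scaled]
    m = max(scaled)                             # numerically stable softmax
    exps = [math.exp(z - m) for z in scaled]
    probs = [e / sum(exps) for e in exps]
    return random.choices(range(len(probs)), weights=probs)[0]

assert sample_top_k([0.1, 2.0, -1.0], temperature=0.0) == 1  # greedy picks the argmax
```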
282
+ def forward(
283
+ self,
284
+ token_ids: torch.Tensor,
285
+ target_ids: torch.Tensor = None,
286
+ ) -> SNNModelOutput:
287
+ """
288
+ Forward pass (all-membrane-potential + dynamic K).
289
+
290
+ encode → h_seq # input (embed repeated K times, differentiable)
291
+ snn_forward → h_out, pc # SNN core (membrane potentials + dynamic-K aggregation)
292
+ decode → logits # output (V_post → K-frame mean → proj → logits)
293
+
294
+ Gradient flow:
295
+ embed_tokens → repeat K → SNN layers (V_post + dynamic K)
296
+ → output_neuron (V_post) → K-frame mean → decode_proj → logits (tied head)
297
+ ponder_cost: dynamic-K regularizer, encouraging fewer steps on easy tokens
298
+ """
299
+ batch, seq_len = token_ids.shape
300
+
301
+ # Reset all neuron states
302
+ for layer_module in self.layers:
303
+ functional.reset_net(layer_module)
304
+ functional.reset_net(self.output_neuron)
305
+
306
+ # Three stages
307
+ spike_seq = self.encode(token_ids) # input boundary
308
+ h_out, ponder_cost = self.snn_forward(spike_seq) # SNN core + ponder cost
309
+ logits = self.decode(h_out, seq_len) # output boundary
310
+
311
+ if target_ids is not None:
312
+ logits_flat = logits.reshape(-1, self.vocab_size)
313
+ targets_flat = target_ids.reshape(-1)
314
+ self.last_loss = F.cross_entropy(
315
+ logits_flat, targets_flat,
316
+ ignore_index=0, reduction='none',
317
+ )
318
+ return SNNModelOutput(
319
+ last_loss=self.last_loss,
320
+ ponder_cost=ponder_cost,
321
+ )
322
+
323
+ return SNNModelOutput(logits=logits, ponder_cost=ponder_cost)
324
+
325
+ def compensate_modulation_gradients(self, max_comp: float = 100.0):
326
+ """
327
+ Natural-gradient compensation (two phases).
328
+
329
+ Phase 1: sigmoid/softplus saturation compensation
330
+ β = sigmoid(b_beta); in the high-β region (β=0.99, sigmoid'=0.01) the gradient shrinks 100x.
331
+ Compensation: grad /= activation'(b), equivalent to gradient descent in β/α space.
332
+
333
+ Phase 2: cross-layer gradient equalization
334
+ Backprop through the residual chain amplifies ~1.17× per layer, ~20× over 20 layers (L0/L19 ratio).
335
+ Selectivity parameters (b_beta/b_alpha/b_th) in deep layers get suppressed gradients and cannot learn effectively.
336
+ Fix: normalize each layer's modulation-parameter gradient norm to the geometric mean over all layers.
337
+
338
+ Call site: after scaler.unscale_(optimizer), before clip_grad_norm_.
339
+
340
+ Args:
341
+ max_comp: upper bound on the compensation factor (guards against instability from extreme values)
342
+ """
343
+ # ====== Phase 1: sigmoid/softplus saturation compensation ======
344
+ for layer_module in self.layers:
345
+ block = layer_module.snn_block
346
+
347
+ # b_beta: sigmoid saturation compensation
348
+ # sigmoid'(z) = sigmoid(z) · (1 - sigmoid(z)) = β · (1-β)
349
+ if block.b_beta.grad is not None:
350
+ with torch.no_grad():
351
+ beta = torch.sigmoid(block.b_beta.data)
352
+ sigmoid_deriv = (beta * (1.0 - beta)).clamp(min=1.0 / max_comp)
353
+ block.b_beta.grad.div_(sigmoid_deriv)
354
+
355
+ # b_alpha: softplus compensation (milder; softplus'(z) = sigmoid(z))
356
+ if block.b_alpha.grad is not None:
357
+ with torch.no_grad():
358
+ softplus_deriv = torch.sigmoid(block.b_alpha.data).clamp(min=0.1)
359
+ block.b_alpha.grad.div_(softplus_deriv)
360
+
361
+ # b_th: derivative of |·| is ±1, no attenuation, so no compensation needed
362
+
363
+ # ====== Phase 2: cross-layer gradient equalization ======
364
+ # Backward path of the residual chain h = h + sublayer(h): ∂h_{l+1}/∂h_l = I + ∂sublayer/∂h_l
365
+ # ~1.17× amplification per layer, ~20× over 20 layers → L0 gradients far exceed L19
366
+ # Normalize each layer's modulation-parameter gradient norm to the geometric mean, removing the residual amplification
367
+ with torch.no_grad():
368
+ for param_name in ['b_beta', 'b_alpha', 'b_th']:
369
+ norms = []
370
+ params_list = []
371
+ for layer_module in self.layers:
372
+ p = getattr(layer_module.snn_block, param_name)
373
+ if p.grad is not None:
374
+ n = p.grad.norm().item()
375
+ if n > 1e-12:
376
+ norms.append(n)
377
+ params_list.append(p)
378
+
379
+ if len(norms) >= 2:
380
+ # Geometric mean: exp(mean(log(norms))); log-scale balancing, robust to extreme values
381
+ log_mean = sum(math.log(n) for n in norms) / len(norms)
382
+ geo_mean = math.exp(log_mean)
383
+ for p, n in zip(params_list, norms):
384
+ scale = geo_mean / n
385
+ scale = max(min(scale, max_comp), 1.0 / max_comp)
386
+ p.grad.mul_(scale)
387
+
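Phase 2's geometric-mean normalization can be checked in isolation (pure-Python sketch; `equalize_norms` is a hypothetical helper mirroring the loop above):

```python
import math

def equalize_norms(norms, max_comp=100.0):
    """Scale factors that pull each layer's gradient norm to the geometric
    mean of all layers' norms, clamped to [1/max_comp, max_comp]."""
    geo = math.exp(sum(math.log(n) for n in norms) / len(norms))
    return [max(min(geo / n, max_comp), 1.0 / max_comp) for n in norms]

scales = equalize_norms([1.0, 4.0])  # geometric mean of 1 and 4 is 2
assert abs(scales[0] - 2.0) < 1e-9   # layer 0 gradient scaled up to the mean
assert abs(scales[1] - 0.5) < 1e-9   # layer 1 gradient scaled down to the mean
```

After scaling, both layers' norms equal the geometric mean, so the product of all norms (and hence the log-scale spread) is preserved rather than flattened toward an arithmetic average.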
388
+ def get_param_groups(self) -> dict[str, list[nn.Parameter]]:
389
+ """
390
+ Trainable parameters grouped by function.
391
+ """
392
+ groups = {
393
+ 'embedding': [self.embed_tokens.weight],
394
+ 'norm': [self.norm.gain],
395
+ 'decode': list(self.decode_proj.parameters()),
396
+ # Output neuron
397
+ 'output_neuron': [self.output_neuron.w, self.output_neuron.v_th],
398
+ # RMSNorm (Pre-LN branch normalization)
399
+ 'rms_norms': [self.output_norm.weight],
400
+ # Residual-stream components
401
+ 'residual_projs': [],
402
+ 'input_neurons': [],
403
+ # Dynamic K: halt projections
404
+ 'halt_projs': [],
405
+ # SNNBlock parameters
406
+ 'W_in': [],
407
+ 'W_beta': [],
408
+ 'W_alpha': [],
409
+ 'W_th': [],
410
+ 'W_gate': [],
411
+ 'W_skip': [],
412
+ 'W_out': [],
413
+ 'b_beta': [],
414
+ 'b_alpha': [],
415
+ 'b_th': [],
416
+ 'block_output_neuron': [],
417
+ # SNNFFN parameters
418
+ 'ffn_gate_proj': [],
419
+ 'ffn_up_proj': [],
420
+ 'ffn_down_proj': [],
421
+ 'ffn_skip_proj': [],
422
+ 'ffn_neurons': [],
423
+ }
424
+
425
+ for layer_module in self.layers:
426
+ block = layer_module.snn_block
427
+ ffn = layer_module.snn_ffn
428
+
429
+ # Residual-stream components
430
+ groups['residual_projs'].extend([
431
+ layer_module.block_out_proj.weight,
432
+ layer_module.ffn_out_proj.weight,
433
+ ])
434
+ groups['input_neurons'].extend([
435
+ layer_module.input_neuron1.w,
436
+ layer_module.input_neuron1.v_th,
437
+ layer_module.input_neuron2.w,
438
+ layer_module.input_neuron2.v_th,
439
+ ])
440
+ groups['rms_norms'].extend([
441
+ layer_module.block_norm.weight,
442
+ layer_module.ffn_norm.weight,
443
+ ])
444
+
445
+ # Dynamic K: halt projection parameters
446
+ groups['halt_projs'].extend(list(layer_module.block_halt.parameters()))
447
+ groups['halt_projs'].extend(list(layer_module.ffn_halt.parameters()))
448
+
449
+ # SNNBlock parameters
450
+ groups['W_in'].append(block.W_in.weight)
451
+ groups['W_beta'].extend([block.W_beta_x.weight])
452
+ groups['W_alpha'].extend([block.W_alpha_x.weight])
453
+ groups['W_th'].extend([block.W_th_x.weight])
454
+ groups['W_gate'].append(block.W_gate.weight)
455
+ groups['W_skip'].append(block.W_skip.weight)
456
+ groups['W_out'].append(block.W_out.weight)
457
+ groups['b_beta'].append(block.b_beta)
458
+ groups['b_alpha'].append(block.b_alpha)
459
+ groups['b_th'].append(block.b_th)
460
+
461
+ # SNNFFN parameters
462
+ groups['ffn_gate_proj'].append(ffn.gate_proj.weight)
463
+ groups['ffn_up_proj'].append(ffn.up_proj.weight)
464
+ groups['ffn_down_proj'].append(ffn.down_proj.weight)
465
+ groups['ffn_skip_proj'].append(ffn.skip_proj.weight)
466
+ groups['ffn_neurons'].extend([
467
+ ffn.gate_neuron.w, ffn.gate_neuron.v_th,
468
+ ffn.up_neuron.w, ffn.up_neuron.v_th,
469
+ ])
470
+
471
+ return groups
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58fac5ad9b11e95d7601686846399b33f84593d29903d68c3dc38c05e4b93d16
3
+ size 3496634368
modeling_neuronspark.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NeuronSpark: SNN hidden-state-space language model (HuggingFace interface)
3
+
4
+ Usage:
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+ model = AutoModelForCausalLM.from_pretrained(
8
+ "checkpoints_sft/", trust_remote_code=True,
9
+ )
10
+ tokenizer = AutoTokenizer.from_pretrained("checkpoints_sft/")
11
+ """
12
+
13
+ from typing import Optional
14
+
15
+ import torch
16
+ from transformers import PreTrainedModel, GenerationMixin
17
+ from transformers.modeling_outputs import CausalLMOutputWithPast
18
+
19
+ from configuration_neuronspark import NeuronSparkConfig
20
+ from model import SNNLanguageModel
21
+
22
+
23
+ class NeuronSparkForCausalLM(PreTrainedModel, GenerationMixin):
24
+ """
25
+ SNN language model with a CausalLM interface.
26
+
27
+ Wraps SNNLanguageModel and exposes the standard HuggingFace interface:
28
+ - forward(input_ids, labels) → CausalLMOutputWithPast
29
+ - generate() support (via GenerationMixin)
30
+ """
31
+ config_class = NeuronSparkConfig
32
+ supports_gradient_checkpointing = True
33
+
34
+ def __init__(self, config: NeuronSparkConfig):
35
+ super().__init__(config)
36
+ self.model = SNNLanguageModel(
37
+ vocab_size=config.vocab_size,
38
+ D=config.D,
39
+ N=config.N,
40
+ K=config.K,
41
+ num_layers=config.num_layers,
42
+ D_ff=config.D_ff,
43
+ v_th_min=config.v_th_min,
44
+ )
45
+
46
+ def get_input_embeddings(self):
47
+ return self.model.embed_tokens
48
+
49
+ def set_input_embeddings(self, value):
50
+ self.model.embed_tokens = value
51
+
52
+ def get_output_embeddings(self):
53
+ # tied head: the output reuses embed_tokens.weight
54
+ return self.model.embed_tokens
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.Tensor,
59
+ labels: Optional[torch.Tensor] = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ **kwargs,
62
+ ) -> CausalLMOutputWithPast:
63
+ """
64
+ Forward pass.
65
+
66
+ Args:
67
+ input_ids: (batch, seq_len) token IDs
68
+ labels: (batch, seq_len) target token IDs (optional, used to compute the loss)
69
+ attention_mask: compatibility argument (the SNN has no attention; ignored)
70
+ """
71
+ if labels is not None:
72
+ out = self.model(input_ids, target_ids=labels)
73
+ # Compute the masked loss
74
+ loss_mask = (labels != 0).float().view(-1)
75
+ loss = (out.last_loss * loss_mask).sum() / loss_mask.sum()
76
+ # Add the ponder cost
77
+ if out.ponder_cost is not None:
78
+ loss = loss + 0.01 * out.ponder_cost
79
+ return CausalLMOutputWithPast(loss=loss)
80
+ else:
81
+ out = self.model(input_ids)
82
+ return CausalLMOutputWithPast(logits=out.logits)
83
+
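The masked-loss computation above (per-token cross-entropy times a pad mask, averaged over non-pad positions) can be sketched with plain numbers (hypothetical helper):

```python
def masked_mean_loss(per_token_loss, labels, pad_id=0):
    """Mean of per-token losses over non-pad positions, mirroring the
    loss_mask computation in forward() above."""
    mask = [1.0 if y != pad_id else 0.0 for y in labels]
    total = sum(l * m for l, m in zip(per_token_loss, mask))
    return total / sum(mask)

# pad position (label 0) is excluded: (2.0 + 4.0) / 2 = 3.0
assert masked_mean_loss([2.0, 4.0, 9.0], [5, 7, 0]) == 3.0
```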
84
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
85
+ """generate() 所需的输入准备。"""
86
+ return {"input_ids": input_ids}
87
+
88
+ @torch.no_grad()
89
+ def generate(
90
+ self,
91
+ input_ids: torch.Tensor,
92
+ max_new_tokens: int = 256,
93
+ temperature: float = 1.0,
94
+ top_k: int = 50,
95
+ eos_token_id: Optional[int] = None,
96
+ **kwargs,
97
+ ) -> torch.Tensor:
98
+ """
99
+ Autoregressive generation (delegates directly to the SNN's generate method).
100
+ """
101
+ return self.model.generate(
102
+ prompt_ids=input_ids,
103
+ max_new_tokens=max_new_tokens,
104
+ temperature=temperature,
105
+ top_k=top_k,
106
+ eos_token_id=eos_token_id,
107
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|im_start|>",
3
+ "eos_token": "<|im_end|>",
4
+ "unk_token": "<unk>",
5
+ "pad_token": "<|im_end|>",
6
+ "additional_special_tokens": [
7
+ "<s>",
8
+ "</s>"
9
+ ]
10
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "bos_token": "<|im_start|>",
6
+ "eos_token": "<|im_end|>",
7
+ "pad_token": "<|im_end|>",
8
+ "unk_token": "<unk>",
9
+ "model_max_length": 1000000000000000019884624838656,
10
+ "clean_up_tokenization_spaces": false,
11
+ "tokenizer_class": "PreTrainedTokenizerFast",
12
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
13
+ }
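The Jinja `chat_template` above produces ChatML-style framing with the `<|im_start|>`/`<|im_end|>` tokens declared in `special_tokens_map.json`. A pure-Python re-implementation of the template (hypothetical helper, for illustration; real use goes through `tokenizer.apply_chat_template`):

```python
def apply_chat_template(messages, add_generation_prompt=False):
    """Pure-Python rendering of the Jinja chat_template above (ChatML framing)."""
    out = ""
    for m in messages:
        if m["role"] in ("system", "user", "assistant"):
            out += "<|im_start|>" + m["role"] + "\n" + m["content"] + "<|im_end|>\n"
    if add_generation_prompt:
        out += "<|im_start|>assistant\n"   # open the assistant turn for generation
    return out

text = apply_chat_template([{"role": "user", "content": "hi"}], add_generation_prompt=True)
assert text == "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n"
```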