Initial release: NeuronSpark-0.9B-Chat instruction-tuned SNN language model
Browse files- LICENSE +190 -0
- README.md +126 -0
- atomic_ops/__init__.py +9 -0
- atomic_ops/fp16_codec.py +98 -0
- atomic_ops/lateral_inhibition.py +215 -0
- atomic_ops/parallel_scan.py +829 -0
- atomic_ops/plif_node.py +81 -0
- atomic_ops/rms_norm.py +36 -0
- atomic_ops/selective_plif.py +94 -0
- atomic_ops/snn_block.py +242 -0
- atomic_ops/snn_decoder_layer.py +327 -0
- atomic_ops/snn_ffn.py +185 -0
- config.json +21 -0
- configuration.json +1 -0
- configuration_neuronspark.py +38 -0
- model.py +471 -0
- model.safetensors +3 -0
- modeling_neuronspark.py +107 -0
- special_tokens_map.json +10 -0
- tokenizer.json +0 -0
- tokenizer_config.json +13 -0
LICENSE
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to the Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by the Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding any notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
Copyright 2025 Zhengzheng Tang
|
| 179 |
+
|
| 180 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 181 |
+
you may not use this file except in compliance with the License.
|
| 182 |
+
You may obtain a copy of the License at
|
| 183 |
+
|
| 184 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 185 |
+
|
| 186 |
+
Unless required by applicable law or agreed to in writing, software
|
| 187 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 188 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 189 |
+
See the License for the specific language governing permissions and
|
| 190 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- zh
|
| 5 |
+
base_model: Brain2nd/NeuronSpark-0.9B
|
| 6 |
+
library_name: transformers
|
| 7 |
+
tags:
|
| 8 |
+
- snn
|
| 9 |
+
- spiking-neural-network
|
| 10 |
+
- text-generation
|
| 11 |
+
- neuromorphic
|
| 12 |
+
- chat
|
| 13 |
+
pipeline_tag: text-generation
|
| 14 |
+
---
|
| 15 |
+
# NeuronSpark-0.9B-Chat
|
| 16 |
+
|
| 17 |
+
## Introduction
|
| 18 |
+
|
| 19 |
+
**NeuronSpark-0.9B-Chat** is the **instruction-tuned chat version** of NeuronSpark-0.9B — a 0.87-billion parameter language model built entirely on Spiking Neural Networks (SNNs). It has been fine-tuned on a small subset of BelleGroup 3.5M Chinese instructions to enable basic dialogue capabilities.
|
| 20 |
+
|
| 21 |
+
> **Note on training data**: Due to limited compute resources, both pretraining and SFT used only **small subsets** of their respective datasets (pretrain ~1.4B of ~10B tokens; SFT ~6.5K steps of ~3.5M samples). Despite this minimal data budget, the model demonstrates coherent Chinese dialogue — validating that pure SNN architectures can learn language from scratch. We plan to scale training with more data and compute in future work.
|
| 22 |
+
|
| 23 |
+
For the pretrained base model, see [NeuronSpark-0.9B](https://modelscope.cn/models/Brain2nd/NeuronSpark-0.9B).
|
| 24 |
+
|
| 25 |
+
## Model Details
|
| 26 |
+
|
| 27 |
+
| Attribute | Value |
|
| 28 |
+
|-----------|-------|
|
| 29 |
+
| Parameters | 874M |
|
| 30 |
+
| Architecture | SNN Hidden State Space Model |
|
| 31 |
+
| Hidden Dimension (D) | 896 |
|
| 32 |
+
| Layers | 20 |
|
| 33 |
+
| SNN Timesteps (K) | 16 (PonderNet adaptive) |
|
| 34 |
+
| State Expansion (N) | 8 |
|
| 35 |
+
| FFN Dimension | 2688 |
|
| 36 |
+
| Vocabulary | 6144 (custom BPE) |
|
| 37 |
+
| Context Length | 512 tokens |
|
| 38 |
+
| Base Model | NeuronSpark-0.9B (pretrained 85K steps) |
|
| 39 |
+
| SFT Data | BelleGroup train_3.5M_CN |
|
| 40 |
+
| SFT Steps | 6,500 |
|
| 41 |
+
| Chat Template | ChatML |
|
| 42 |
+
| License | Apache 2.0 |
|
| 43 |
+
|
| 44 |
+
## Quickstart
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 48 |
+
|
| 49 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 50 |
+
"Brain2nd/NeuronSpark-0.9B-Chat",
|
| 51 |
+
trust_remote_code=True,
|
| 52 |
+
)
|
| 53 |
+
tokenizer = AutoTokenizer.from_pretrained("Brain2nd/NeuronSpark-0.9B-Chat")
|
| 54 |
+
|
| 55 |
+
# Chat
|
| 56 |
+
messages = [
|
| 57 |
+
{"role": "system", "content": "你是一个AI助手"},
|
| 58 |
+
{"role": "user", "content": "中国的首都是哪里?"},
|
| 59 |
+
]
|
| 60 |
+
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 61 |
+
input_ids = tokenizer(text, return_tensors="pt")["input_ids"]
|
| 62 |
+
|
| 63 |
+
output_ids = model.generate(
|
| 64 |
+
input_ids,
|
| 65 |
+
max_new_tokens=256,
|
| 66 |
+
temperature=0.1,
|
| 67 |
+
top_k=10,
|
| 68 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Extract assistant response
|
| 72 |
+
full_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
|
| 73 |
+
response = full_text.split("assistant\n")[-1].replace("<|im_end|>", "").strip()
|
| 74 |
+
print(response)
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
**Example Output:**
|
| 78 |
+
```
|
| 79 |
+
Q: 中国的首都是哪里?
|
| 80 |
+
A: 中国的首都在北京。
|
| 81 |
+
|
| 82 |
+
Q: 你好呀
|
| 83 |
+
A: 请问您需要什么样的帮助?
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
## Architecture Highlights
|
| 87 |
+
|
| 88 |
+
- **Pure SNN**: No attention, no standard MLP — all computation via PLIF (Parametric Leaky Integrate-and-Fire) neurons
|
| 89 |
+
- **Membrane Potential Leakage Activation**: PLIFNode outputs `(1-β)·V_post` (leak current), naturally emphasizing fast-responding neurons
|
| 90 |
+
- **Selective State Space**: Hidden neurons with input-dependent dynamic β(t), α(t), V_th(t)
|
| 91 |
+
- **PonderNet Adaptive K**: Each token dynamically decides how many SNN timesteps to use
|
| 92 |
+
- **Triton Fused Kernels**: Custom PLIF forward/backward kernels for efficient parallel scan
|
| 93 |
+
- **ChatML Template**: Compatible with standard chat formatting
|
| 94 |
+
|
| 95 |
+
## Requirements
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
pip install torch transformers spikingjelly safetensors
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Limitations
|
| 102 |
+
|
| 103 |
+
- **Context length**: 512 tokens (limited by training configuration)
|
| 104 |
+
- **Knowledge**: Trained on Chinese corpus only; limited factual accuracy
|
| 105 |
+
- **Repetition**: May generate repetitive text for complex queries
|
| 106 |
+
- **Scale**: 0.9B parameters — significantly smaller than state-of-the-art chat models
|
| 107 |
+
|
| 108 |
+
This is a **research model** demonstrating that SNN architectures can achieve basic language understanding and dialogue, even with very limited training data. It is not intended for production use. We plan to continue scaling with more data and compute.
|
| 109 |
+
|
| 110 |
+
## Citation
|
| 111 |
+
|
| 112 |
+
```bibtex
|
| 113 |
+
@misc{neuronspark2025,
|
| 114 |
+
title={NeuronSpark: A Spiking Neural Network Language Model with Selective State Space Dynamics},
|
| 115 |
+
author={Zhengzheng Tang},
|
| 116 |
+
year={2025},
|
| 117 |
+
url={https://github.com/Brain2nd/NeuronSpark}
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
## Contact
|
| 122 |
+
|
| 123 |
+
- **Author**: Zhengzheng Tang
|
| 124 |
+
- **Email**: zztangbu@bu.edu
|
| 125 |
+
- **GitHub**: [Brain2nd/NeuronSpark](https://github.com/Brain2nd/NeuronSpark)
|
| 126 |
+
|
atomic_ops/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .selective_plif import SelectivePLIFNode
|
| 2 |
+
from .plif_node import PLIFNode
|
| 3 |
+
from .lateral_inhibition import LateralInhibition
|
| 4 |
+
from .snn_block import SNNBlock
|
| 5 |
+
from .snn_ffn import SNNFFN
|
| 6 |
+
from .snn_decoder_layer import SNNDecoderLayer
|
| 7 |
+
from .parallel_scan import hillis_steele_scan, linear_recurrence, plif_parallel_forward
|
| 8 |
+
from .fp16_codec import fp16_encode, fp16_decode
|
| 9 |
+
from .rms_norm import RMSNorm
|
atomic_ops/fp16_codec.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FP16 二进制编码/解码 — 模型边界操作(无可训练参数)。
|
| 3 |
+
|
| 4 |
+
IEEE 754 float16 位布局(K=16 时间步):
|
| 5 |
+
时间步: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
|
| 6 |
+
位: sign E4 E3 E2 E1 E0 M9 M8 M7 M6 M5 M4 M3 M2 M1 M0
|
| 7 |
+
含义: 符号 ←── 指数(bias=15) ──→ ←────────── 尾数(隐含 1.xxx) ──────────→
|
| 8 |
+
|
| 9 |
+
编码: embedding → IEEE 754 float16 位提取 → 16 帧二值 spike(detach,固定预处理)
|
| 10 |
+
解码: 16 帧二值 spike → IEEE 754 位重建 → 连续值(可微分,梯度通过 surrogate grad 传播)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def fp16_encode(emb: Tensor, K: int = 16) -> Tensor:
|
| 18 |
+
"""FP16 二进制编码(模型边界操作,固定预处理)。
|
| 19 |
+
|
| 20 |
+
将连续 embedding 转为 IEEE 754 float16 位模式,作为 SNN 的 spike 输入。
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
emb: (batch, seq_len, D) 连续 embedding
|
| 24 |
+
K: 时间步数(必须为 16,对应 float16 的 16 位)
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
spike_seq: (seq_len*K, batch, D) 二值 {0, 1}, detached
|
| 28 |
+
"""
|
| 29 |
+
batch, seq_len, D = emb.shape
|
| 30 |
+
|
| 31 |
+
# 转为 float16 获取 IEEE 754 位模式
|
| 32 |
+
# clamp 防止 overflow 产生 Inf(float16 最大值 65504)
|
| 33 |
+
emb_fp16 = emb.float().clamp(-65504.0, 65504.0).half()
|
| 34 |
+
bits_int = emb_fp16.view(torch.int16) # (batch, seq_len, D)
|
| 35 |
+
|
| 36 |
+
# 提取 16 位(MSB first: sign, exponent, mantissa)
|
| 37 |
+
shifts = torch.arange(15, -1, -1, device=emb.device) # [15, 14, ..., 0]
|
| 38 |
+
# bits_int: (batch, seq_len, D) → unsqueeze → (batch, seq_len, 1, D)
|
| 39 |
+
# shifts: (K,) → view → (1, 1, K, 1)
|
| 40 |
+
bits = ((bits_int.unsqueeze(2) >> shifts.view(1, 1, K, 1)) & 1) # (batch, seq_len, K, D)
|
| 41 |
+
|
| 42 |
+
# 转为计算 dtype 并 detach(编码不参与梯度)
|
| 43 |
+
bits = bits.to(emb.dtype).detach()
|
| 44 |
+
|
| 45 |
+
# reshape → (seq_len*K, batch, D)
|
| 46 |
+
return bits.reshape(batch, seq_len * K, D).permute(1, 0, 2).contiguous()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def fp16_decode(spikes: Tensor, seq_len: int, K: int = 16) -> Tensor:
|
| 50 |
+
"""FP16 精确位解码:从 16 个二值 spike 重建 float16 值。
|
| 51 |
+
|
| 52 |
+
fp16_encode 的精确逆操作。全程可微分——梯度通过 IEEE 754 重建公式
|
| 53 |
+
传到每个 spike 输出,再经 surrogate gradient 传入 SNN。
|
| 54 |
+
|
| 55 |
+
IEEE 754 float16 重建:
|
| 56 |
+
Normal (exp > 0): (-1)^sign * 2^(exp - 15) * (1 + mant_frac)
|
| 57 |
+
Subnormal (exp = 0): (-1)^sign * 2^(-14) * mant_frac
|
| 58 |
+
其中 mant_frac = Σ mant_bit_i * 2^{-(i+1)}, i=0..9
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
spikes: (seq_len*K, batch, D) 二值 {0, 1}(输出神经元的 spike)
|
| 62 |
+
seq_len: token 序列长度
|
| 63 |
+
K: 时间步数(= 16)
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
decoded: (batch, seq_len, D) 连续值
|
| 67 |
+
"""
|
| 68 |
+
batch, D = spikes.shape[1], spikes.shape[2]
|
| 69 |
+
|
| 70 |
+
# (seq_len*K, batch, D) → (batch, seq_len, K, D)
|
| 71 |
+
s = spikes.permute(1, 0, 2).reshape(batch, seq_len, K, D)
|
| 72 |
+
|
| 73 |
+
# ---- Sign: bit 0 ----
|
| 74 |
+
sign = 1.0 - 2.0 * s[:, :, 0, :] # +1 or -1
|
| 75 |
+
|
| 76 |
+
# ---- Exponent: bits 1-5, 加权求和 → 整数 0~31 ----
|
| 77 |
+
exp_weights = torch.tensor(
|
| 78 |
+
[16.0, 8.0, 4.0, 2.0, 1.0],
|
| 79 |
+
device=spikes.device, dtype=spikes.dtype,
|
| 80 |
+
)
|
| 81 |
+
exp_val = (s[:, :, 1:6, :] * exp_weights.view(1, 1, 5, 1)).sum(dim=2)
|
| 82 |
+
|
| 83 |
+
# ---- Mantissa fraction: bits 6-15, 加权求和 → [0, 1) ----
|
| 84 |
+
mant_weights = torch.tensor(
|
| 85 |
+
[2.0 ** (-i) for i in range(1, 11)],
|
| 86 |
+
device=spikes.device, dtype=spikes.dtype,
|
| 87 |
+
)
|
| 88 |
+
mant_frac = (s[:, :, 6:, :] * mant_weights.view(1, 1, 10, 1)).sum(dim=2)
|
| 89 |
+
|
| 90 |
+
# ---- IEEE 754 重建 ----
|
| 91 |
+
# Normal: (-1)^s * 2^(exp-15) * (1 + mant_frac)
|
| 92 |
+
# Subnormal: (-1)^s * 2^(-14) * mant_frac
|
| 93 |
+
is_normal = (exp_val > 0)
|
| 94 |
+
|
| 95 |
+
normal_val = sign * torch.exp2(exp_val - 15.0) * (1.0 + mant_frac)
|
| 96 |
+
subnormal_val = sign * (2.0 ** -14) * mant_frac
|
| 97 |
+
|
| 98 |
+
return torch.where(is_normal, normal_val, subnormal_val)
|
atomic_ops/lateral_inhibition.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LateralInhibition: 侧抑制归一化(Divisive Normalization)
|
| 3 |
+
|
| 4 |
+
神经科学基础:
|
| 5 |
+
Carandini & Heeger (2012) "Normalization as a canonical neural computation"
|
| 6 |
+
侧抑制是大脑中最基本的计算原语之一:兴奋性神经元的活动通过抑制性中间神经元池
|
| 7 |
+
反馈调节,实现增益控制(gain control)。
|
| 8 |
+
|
| 9 |
+
SNN 机制:
|
| 10 |
+
1. 兴奋性群体活动: activity_i = h_i²
|
| 11 |
+
2. 抑制性中间神经元池: pool = mean(activity) = mean(h²)
|
| 12 |
+
3. 分裂抑制 (shunting inhibition): h_norm = h / sqrt(pool + ε)
|
| 13 |
+
4. 增益调制 (gain modulation): output = gain · h_norm
|
| 14 |
+
|
| 15 |
+
替换 RMSNorm:数学操作等价,但在 SNN 框架中有明确的神经科学对应——
|
| 16 |
+
RMSNorm 是 divisive normalization 的特例。
|
| 17 |
+
|
| 18 |
+
Triton fused kernel:
|
| 19 |
+
- 前向: {mean(h²), rsqrt, element-wise mul} → 1 kernel launch
|
| 20 |
+
- 反向: {recompute norm, grad_gain, grad_h} → 1 kernel launch
|
| 21 |
+
- 每行 (D dim) 一个 block,行间并行
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from spikingjelly.activation_based import base
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ============================================================
|
| 32 |
+
# Triton fused kernels
|
| 33 |
+
# ============================================================
|
| 34 |
+
|
| 35 |
+
_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
|
| 36 |
+
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
|
| 37 |
+
os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS
|
| 38 |
+
|
| 39 |
+
_HAS_TRITON = False
|
| 40 |
+
try:
|
| 41 |
+
import triton
|
| 42 |
+
import triton.language as tl
|
| 43 |
+
_HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if _HAS_TRITON:
|
| 49 |
+
|
| 50 |
+
@triton.jit
|
| 51 |
+
def _li_fwd_kernel(
|
| 52 |
+
X_ptr, GAIN_ptr, OUT_ptr,
|
| 53 |
+
stride_row,
|
| 54 |
+
D: tl.constexpr,
|
| 55 |
+
eps: tl.constexpr,
|
| 56 |
+
BLOCK_D: tl.constexpr,
|
| 57 |
+
):
|
| 58 |
+
"""Forward: out = x * rsqrt(mean(x²) + eps) * gain
|
| 59 |
+
|
| 60 |
+
Grid: (num_rows,). Each program processes one row of D elements.
|
| 61 |
+
Computation in float32; storage in input dtype.
|
| 62 |
+
"""
|
| 63 |
+
row = tl.program_id(0)
|
| 64 |
+
cols = tl.arange(0, BLOCK_D)
|
| 65 |
+
mask = cols < D
|
| 66 |
+
off = row * stride_row + cols
|
| 67 |
+
|
| 68 |
+
# Load in float32
|
| 69 |
+
x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 70 |
+
gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 71 |
+
|
| 72 |
+
# Inhibitory pool: population activity
|
| 73 |
+
variance = tl.sum(x * x, axis=0) / D
|
| 74 |
+
rrms = 1.0 / tl.sqrt(variance + eps)
|
| 75 |
+
|
| 76 |
+
# Divisive inhibition + gain modulation
|
| 77 |
+
out = x * rrms * gain
|
| 78 |
+
|
| 79 |
+
tl.store(OUT_ptr + off, out, mask=mask)
|
| 80 |
+
|
| 81 |
+
@triton.jit
|
| 82 |
+
def _li_bwd_kernel(
|
| 83 |
+
DOUT_ptr, X_ptr, GAIN_ptr,
|
| 84 |
+
DX_ptr, DGAIN_ptr,
|
| 85 |
+
stride_row,
|
| 86 |
+
D: tl.constexpr,
|
| 87 |
+
eps: tl.constexpr,
|
| 88 |
+
BLOCK_D: tl.constexpr,
|
| 89 |
+
):
|
| 90 |
+
"""Backward: grad_x, grad_gain (per-row, reduced externally).
|
| 91 |
+
|
| 92 |
+
Grid: (num_rows,).
|
| 93 |
+
d_x = rrms * (d_out * gain - x_hat * mean(d_out * gain * x_hat))
|
| 94 |
+
d_gain_row = d_out * x_hat (sum across rows done outside kernel)
|
| 95 |
+
"""
|
| 96 |
+
row = tl.program_id(0)
|
| 97 |
+
cols = tl.arange(0, BLOCK_D)
|
| 98 |
+
mask = cols < D
|
| 99 |
+
off = row * stride_row + cols
|
| 100 |
+
|
| 101 |
+
dout = tl.load(DOUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 102 |
+
x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 103 |
+
gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 104 |
+
|
| 105 |
+
# Recompute forward (avoid saving intermediate tensors)
|
| 106 |
+
variance = tl.sum(x * x, axis=0) / D
|
| 107 |
+
rrms = 1.0 / tl.sqrt(variance + eps)
|
| 108 |
+
x_hat = x * rrms
|
| 109 |
+
|
| 110 |
+
# grad_gain (per-row contribution)
|
| 111 |
+
dgain = dout * x_hat
|
| 112 |
+
tl.store(DGAIN_ptr + off, dgain, mask=mask)
|
| 113 |
+
|
| 114 |
+
# grad_x: rrms * (dout*gain - x_hat * mean(dout*gain*x_hat))
|
| 115 |
+
dout_gain = dout * gain
|
| 116 |
+
dot = tl.sum(dout_gain * x_hat, axis=0) / D
|
| 117 |
+
dx = (dout_gain - x_hat * dot) * rrms
|
| 118 |
+
|
| 119 |
+
tl.store(DX_ptr + off, dx, mask=mask)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class _LateralInhibitionTriton(torch.autograd.Function):
|
| 123 |
+
"""Triton-accelerated lateral inhibition (divisive normalization)."""
|
| 124 |
+
|
| 125 |
+
@staticmethod
|
| 126 |
+
def forward(ctx, x, gain, eps):
|
| 127 |
+
orig_shape = x.shape
|
| 128 |
+
D = x.shape[-1]
|
| 129 |
+
x_2d = x.reshape(-1, D).contiguous()
|
| 130 |
+
N = x_2d.shape[0]
|
| 131 |
+
|
| 132 |
+
out = torch.empty_like(x_2d)
|
| 133 |
+
|
| 134 |
+
BLOCK_D = triton.next_power_of_2(D)
|
| 135 |
+
_li_fwd_kernel[(N,)](
|
| 136 |
+
x_2d, gain, out,
|
| 137 |
+
x_2d.stride(0),
|
| 138 |
+
D=D, eps=eps, BLOCK_D=BLOCK_D,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
ctx.save_for_backward(x_2d, gain)
|
| 142 |
+
ctx.eps = eps
|
| 143 |
+
ctx.orig_shape = orig_shape
|
| 144 |
+
ctx.N = N
|
| 145 |
+
ctx.D = D
|
| 146 |
+
|
| 147 |
+
return out.reshape(orig_shape)
|
| 148 |
+
|
| 149 |
+
@staticmethod
|
| 150 |
+
def backward(ctx, grad_output):
|
| 151 |
+
x_2d, gain = ctx.saved_tensors
|
| 152 |
+
D = ctx.D
|
| 153 |
+
N = ctx.N
|
| 154 |
+
|
| 155 |
+
grad_2d = grad_output.reshape(N, D).contiguous()
|
| 156 |
+
|
| 157 |
+
dx = torch.empty_like(x_2d)
|
| 158 |
+
dgain_rows = torch.empty_like(x_2d)
|
| 159 |
+
|
| 160 |
+
BLOCK_D = triton.next_power_of_2(D)
|
| 161 |
+
_li_bwd_kernel[(N,)](
|
| 162 |
+
grad_2d, x_2d, gain,
|
| 163 |
+
dx, dgain_rows,
|
| 164 |
+
x_2d.stride(0),
|
| 165 |
+
D=D, eps=ctx.eps, BLOCK_D=BLOCK_D,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# Reduce per-row dgain across all rows
|
| 169 |
+
dgain = dgain_rows.sum(dim=0)
|
| 170 |
+
|
| 171 |
+
return dx.reshape(ctx.orig_shape), dgain, None
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ============================================================
|
| 175 |
+
# Public module
|
| 176 |
+
# ============================================================
|
| 177 |
+
|
| 178 |
+
class LateralInhibition(base.MemoryModule):
|
| 179 |
+
"""
|
| 180 |
+
侧抑制归一化层(Divisive Normalization)。
|
| 181 |
+
|
| 182 |
+
通过抑制性中间神经元池实现增益控制。
|
| 183 |
+
|
| 184 |
+
数学:
|
| 185 |
+
pool = mean(h², dim=-1) # 抑制性池:群体活动水平
|
| 186 |
+
h_norm = h / sqrt(pool + ε) # 分裂抑制 (shunting inhibition)
|
| 187 |
+
output = gain · h_norm # 增益调制 (gain modulation)
|
| 188 |
+
|
| 189 |
+
等价于 RMSNorm,但在 SNN 框架中对应 divisive normalization
|
| 190 |
+
(Carandini & Heeger, 2012),是神经科学中最基本的计算原语之一。
|
| 191 |
+
|
| 192 |
+
CUDA: Triton fused kernel(前向+反向各 1 次 launch)
|
| 193 |
+
CPU: PyTorch fallback
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
dim: 特征维度(D)
|
| 197 |
+
eps: 数值稳定性
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.gain = nn.Parameter(torch.ones(dim))
|
| 203 |
+
self.eps = eps
|
| 204 |
+
self.dim = dim
|
| 205 |
+
|
| 206 |
+
def forward(self, h: torch.Tensor) -> torch.Tensor:
|
| 207 |
+
if _HAS_TRITON and h.is_cuda:
|
| 208 |
+
return _LateralInhibitionTriton.apply(h, self.gain, self.eps)
|
| 209 |
+
# PyTorch fallback
|
| 210 |
+
variance = h.pow(2).mean(-1, keepdim=True)
|
| 211 |
+
h_norm = h * torch.rsqrt(variance + self.eps)
|
| 212 |
+
return self.gain * h_norm
|
| 213 |
+
|
| 214 |
+
def extra_repr(self):
|
| 215 |
+
return f'dim={self.dim}, eps={self.eps}'
|
atomic_ops/parallel_scan.py
ADDED
|
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Parallel Scan 工具函数:SNN 线性递推的高效并行求解
|
| 3 |
+
|
| 4 |
+
实现三层后端:
|
| 5 |
+
1. Fused PLIF kernel(默认,CUDA + Sigmoid surrogate):
|
| 6 |
+
单 kernel 完成 PLIF 前向(scan + spike + soft reset)和反向(surrogate gradient)
|
| 7 |
+
· per-element beta/v_th: _fused_plif_fwd_kernel / _fused_plif_bwd_kernel
|
| 8 |
+
· row-param beta/v_th: _fused_plif_fwd_rowparam_kernel / _fused_plif_bwd_rowparam_kernel
|
| 9 |
+
2. Triton linear_recurrence(CUDA,非 Sigmoid 或无 surrogate):
|
| 10 |
+
列级并行 scan,O(K) 工作量,1 次 kernel launch
|
| 11 |
+
3. Hillis-Steele parallel scan(CPU 回退):O(K log K) 工作量
|
| 12 |
+
|
| 13 |
+
线性递推:
|
| 14 |
+
V[k] = a[k] * V[k-1] + b[k], V[-1] = v_init
|
| 15 |
+
|
| 16 |
+
PLIF 神经元动力学:
|
| 17 |
+
V_pre[k] = beta[k] * V_post[k-1] + u[k]
|
| 18 |
+
s[k] = Θ(V_pre[k] - v_th[k])
|
| 19 |
+
V_post[k] = V_pre[k] - v_th[k] * s[k]
|
| 20 |
+
|
| 21 |
+
数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ============================================================
|
| 29 |
+
# Triton fused recurrence kernels
|
| 30 |
+
# ============================================================
|
| 31 |
+
|
| 32 |
+
# DGX Spark (GB10, sm_121a): Triton 3.5.1 自带 ptxas 不支持 sm_121a,
|
| 33 |
+
# 需要使用系统 CUDA 13.0 的 ptxas
|
| 34 |
+
_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
|
| 35 |
+
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
|
| 36 |
+
os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS
|
| 37 |
+
|
| 38 |
+
_HAS_TRITON = False
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
_HAS_TRITON = True
|
| 43 |
+
except ImportError:
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
if _HAS_TRITON:
|
| 47 |
+
|
| 48 |
+
@triton.jit
|
| 49 |
+
def _fwd_recurrence_kernel(
|
| 50 |
+
A_ptr, B_ptr, INIT_ptr, OUT_ptr,
|
| 51 |
+
K, num_cols,
|
| 52 |
+
BLOCK: tl.constexpr,
|
| 53 |
+
):
|
| 54 |
+
"""Forward: V[k] = A[k]*V[k-1] + B[k], V[-1] = init.
|
| 55 |
+
|
| 56 |
+
Grid: (ceil(num_cols / BLOCK),)
|
| 57 |
+
Each program processes BLOCK columns across all K sequential steps.
|
| 58 |
+
Accumulation in float32; storage in input dtype.
|
| 59 |
+
"""
|
| 60 |
+
pid = tl.program_id(0)
|
| 61 |
+
cols = pid * BLOCK + tl.arange(0, BLOCK)
|
| 62 |
+
mask = cols < num_cols
|
| 63 |
+
|
| 64 |
+
v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 65 |
+
|
| 66 |
+
for k in range(K):
|
| 67 |
+
off = k * num_cols + cols
|
| 68 |
+
a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 69 |
+
b = tl.load(B_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 70 |
+
v = a * v + b
|
| 71 |
+
tl.store(OUT_ptr + off, v, mask=mask)
|
| 72 |
+
|
| 73 |
+
@triton.jit
|
| 74 |
+
def _bwd_recurrence_kernel(
|
| 75 |
+
A_ptr, V_ptr, INIT_ptr, GRAD_OUT_ptr,
|
| 76 |
+
GRAD_A_ptr, GRAD_B_ptr, GRAD_INIT_ptr,
|
| 77 |
+
K, num_cols,
|
| 78 |
+
BLOCK: tl.constexpr,
|
| 79 |
+
):
|
| 80 |
+
"""Backward for V[k] = A[k]*V[k-1] + B[k].
|
| 81 |
+
|
| 82 |
+
Reverse accumulation (k from K-1 down to 0):
|
| 83 |
+
g = 0
|
| 84 |
+
for k = K-1, ..., 0:
|
| 85 |
+
g += grad_out[k]
|
| 86 |
+
grad_B[k] = g
|
| 87 |
+
grad_A[k] = g * V[k-1] (V[-1] = init)
|
| 88 |
+
g = g * A[k]
|
| 89 |
+
grad_init = g
|
| 90 |
+
"""
|
| 91 |
+
pid = tl.program_id(0)
|
| 92 |
+
cols = pid * BLOCK + tl.arange(0, BLOCK)
|
| 93 |
+
mask = cols < num_cols
|
| 94 |
+
|
| 95 |
+
g = tl.zeros([BLOCK], dtype=tl.float32)
|
| 96 |
+
|
| 97 |
+
for k_rev in range(K):
|
| 98 |
+
k = K - 1 - k_rev
|
| 99 |
+
off = k * num_cols + cols
|
| 100 |
+
|
| 101 |
+
dV = tl.load(GRAD_OUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 102 |
+
g = g + dV
|
| 103 |
+
|
| 104 |
+
tl.store(GRAD_B_ptr + off, g, mask=mask)
|
| 105 |
+
|
| 106 |
+
if k > 0:
|
| 107 |
+
v_prev = tl.load(
|
| 108 |
+
V_ptr + (k - 1) * num_cols + cols,
|
| 109 |
+
mask=mask, other=0.0,
|
| 110 |
+
).to(tl.float32)
|
| 111 |
+
else:
|
| 112 |
+
v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 113 |
+
tl.store(GRAD_A_ptr + off, g * v_prev, mask=mask)
|
| 114 |
+
|
| 115 |
+
a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 116 |
+
g = g * a
|
| 117 |
+
|
| 118 |
+
tl.store(GRAD_INIT_ptr + cols, g, mask=mask)
|
| 119 |
+
|
| 120 |
+
class _TritonLinearRecurrence(torch.autograd.Function):
|
| 121 |
+
"""Fused Triton linear recurrence: V[k] = A[k]*V[k-1] + B[k]."""
|
| 122 |
+
|
| 123 |
+
_BLOCK = 128
|
| 124 |
+
|
| 125 |
+
@staticmethod
|
| 126 |
+
def forward(ctx, beta, u, v_init):
|
| 127 |
+
beta_c = beta.contiguous()
|
| 128 |
+
u_c = u.contiguous()
|
| 129 |
+
v_init_c = v_init.contiguous()
|
| 130 |
+
|
| 131 |
+
K = beta_c.shape[0]
|
| 132 |
+
num_cols = beta_c[0].numel()
|
| 133 |
+
V = torch.empty_like(u_c)
|
| 134 |
+
|
| 135 |
+
BLOCK = _TritonLinearRecurrence._BLOCK
|
| 136 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 137 |
+
|
| 138 |
+
_fwd_recurrence_kernel[grid](
|
| 139 |
+
beta_c, u_c, v_init_c, V,
|
| 140 |
+
K, num_cols,
|
| 141 |
+
BLOCK=BLOCK,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
|
| 145 |
+
ctx.save_for_backward(beta_c, V, v_init_c)
|
| 146 |
+
ctx.K = K
|
| 147 |
+
ctx.num_cols = num_cols
|
| 148 |
+
|
| 149 |
+
return V
|
| 150 |
+
|
| 151 |
+
@staticmethod
|
| 152 |
+
def backward(ctx, grad_V):
|
| 153 |
+
beta, V, v_init = ctx.saved_tensors
|
| 154 |
+
grad_V_c = grad_V.contiguous()
|
| 155 |
+
|
| 156 |
+
K = ctx.K
|
| 157 |
+
num_cols = ctx.num_cols
|
| 158 |
+
|
| 159 |
+
grad_beta = torch.empty_like(beta)
|
| 160 |
+
grad_u = torch.empty_like(beta)
|
| 161 |
+
grad_v_init = torch.empty_like(v_init)
|
| 162 |
+
|
| 163 |
+
BLOCK = _TritonLinearRecurrence._BLOCK
|
| 164 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 165 |
+
|
| 166 |
+
_bwd_recurrence_kernel[grid](
|
| 167 |
+
beta, V, v_init, grad_V_c,
|
| 168 |
+
grad_beta, grad_u, grad_v_init,
|
| 169 |
+
K, num_cols,
|
| 170 |
+
BLOCK=BLOCK,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return grad_beta, grad_u, grad_v_init
|
| 174 |
+
|
| 175 |
+
# ============================================================
|
| 176 |
+
# Fused PLIF forward/backward kernels
|
| 177 |
+
# ============================================================
|
| 178 |
+
|
| 179 |
+
@triton.jit
|
| 180 |
+
def _fused_plif_fwd_kernel(
|
| 181 |
+
BETA_ptr, U_ptr, VTH_ptr, INIT_ptr,
|
| 182 |
+
SPIKE_ptr, VPOST_ptr,
|
| 183 |
+
K, num_cols,
|
| 184 |
+
BLOCK: tl.constexpr,
|
| 185 |
+
):
|
| 186 |
+
"""Fused PLIF forward: single-pass sequential scan with inline spike + soft reset.
|
| 187 |
+
|
| 188 |
+
Exact computation — sequential scan IS the ground truth.
|
| 189 |
+
Replaces the 3-phase approach (linear scan + spike iteration + correction).
|
| 190 |
+
|
| 191 |
+
Per column (parallel across batch*D):
|
| 192 |
+
v = v_init
|
| 193 |
+
for k = 0..K-1:
|
| 194 |
+
v_pre = beta[k]*v + u[k]
|
| 195 |
+
spike[k] = Θ(v_pre - v_th[k])
|
| 196 |
+
v = v_pre - v_th[k]*spike[k]
|
| 197 |
+
"""
|
| 198 |
+
pid = tl.program_id(0)
|
| 199 |
+
cols = pid * BLOCK + tl.arange(0, BLOCK)
|
| 200 |
+
mask = cols < num_cols
|
| 201 |
+
|
| 202 |
+
v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 203 |
+
|
| 204 |
+
for k in range(K):
|
| 205 |
+
off = k * num_cols + cols
|
| 206 |
+
beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 207 |
+
u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 208 |
+
vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 209 |
+
|
| 210 |
+
v_pre = beta * v + u
|
| 211 |
+
spike = tl.where(v_pre >= vth, 1.0, 0.0)
|
| 212 |
+
v = v_pre - vth * spike # soft reset
|
| 213 |
+
|
| 214 |
+
tl.store(SPIKE_ptr + off, spike, mask=mask)
|
| 215 |
+
tl.store(VPOST_ptr + off, v, mask=mask)
|
| 216 |
+
|
| 217 |
+
@triton.jit
|
| 218 |
+
def _fused_plif_bwd_kernel(
|
| 219 |
+
BETA_ptr, VTH_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
|
| 220 |
+
GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
|
| 221 |
+
GRAD_BETA_ptr, GRAD_U_ptr, GRAD_VTH_ptr, GRAD_INIT_ptr,
|
| 222 |
+
K, num_cols, ALPHA,
|
| 223 |
+
BLOCK: tl.constexpr,
|
| 224 |
+
):
|
| 225 |
+
"""Fused PLIF backward: single reverse pass with Sigmoid surrogate gradient.
|
| 226 |
+
|
| 227 |
+
V_pre[k] = V_post[k] + v_th[k]*spike[k] (reconstructed)
|
| 228 |
+
surrogate_grad(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
|
| 229 |
+
where x = V_pre[k] - v_th[k] = V_post[k] - v_th[k]*(1 - spike[k])
|
| 230 |
+
|
| 231 |
+
Reverse accumulation:
|
| 232 |
+
acc = 0
|
| 233 |
+
for k = K-1 downto 0:
|
| 234 |
+
total_gV = grad_V_post[k] + acc
|
| 235 |
+
sg = surrogate_grad(V_pre[k] - v_th[k])
|
| 236 |
+
grad_v_pre = grad_spike[k]*sg + total_gV
|
| 237 |
+
grad_beta[k] = grad_v_pre * V_post[k-1]
|
| 238 |
+
grad_u[k] = grad_v_pre
|
| 239 |
+
grad_v_th[k] = -grad_spike[k]*sg - total_gV*spike[k]
|
| 240 |
+
acc = grad_v_pre * beta[k]
|
| 241 |
+
grad_v_init = acc
|
| 242 |
+
"""
|
| 243 |
+
pid = tl.program_id(0)
|
| 244 |
+
cols = pid * BLOCK + tl.arange(0, BLOCK)
|
| 245 |
+
mask = cols < num_cols
|
| 246 |
+
|
| 247 |
+
acc = tl.zeros([BLOCK], dtype=tl.float32)
|
| 248 |
+
|
| 249 |
+
for k_rev in range(K):
|
| 250 |
+
k = K - 1 - k_rev
|
| 251 |
+
off = k * num_cols + cols
|
| 252 |
+
|
| 253 |
+
beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 254 |
+
vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 255 |
+
v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 256 |
+
spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 257 |
+
|
| 258 |
+
g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 259 |
+
g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 260 |
+
|
| 261 |
+
# V_post[k-1]
|
| 262 |
+
if k > 0:
|
| 263 |
+
v_prev = tl.load(
|
| 264 |
+
VPOST_ptr + (k - 1) * num_cols + cols,
|
| 265 |
+
mask=mask, other=0.0,
|
| 266 |
+
).to(tl.float32)
|
| 267 |
+
else:
|
| 268 |
+
v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 269 |
+
|
| 270 |
+
# Sigmoid surrogate gradient
|
| 271 |
+
x = v_post - vth * (1.0 - spike) # = V_pre - v_th
|
| 272 |
+
neg_ax = -ALPHA * x
|
| 273 |
+
neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax) # prevent exp overflow
|
| 274 |
+
sig = 1.0 / (1.0 + tl.exp(neg_ax))
|
| 275 |
+
sg = ALPHA * sig * (1.0 - sig)
|
| 276 |
+
|
| 277 |
+
total_gV = g_V + acc
|
| 278 |
+
grad_v_pre = g_s * sg + total_gV
|
| 279 |
+
|
| 280 |
+
tl.store(GRAD_BETA_ptr + off, grad_v_pre * v_prev, mask=mask)
|
| 281 |
+
tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
|
| 282 |
+
tl.store(GRAD_VTH_ptr + off, -g_s * sg - total_gV * spike, mask=mask)
|
| 283 |
+
|
| 284 |
+
acc = grad_v_pre * beta
|
| 285 |
+
|
| 286 |
+
tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
|
| 287 |
+
|
| 288 |
+
# ============================================================
|
| 289 |
+
# Fused PLIF kernels with row-parameter beta/v_th
|
| 290 |
+
# (constant across K steps — e.g., ParametricLIFNode scalars)
|
| 291 |
+
# ============================================================
|
| 292 |
+
|
| 293 |
+
@triton.jit
|
| 294 |
+
def _fused_plif_fwd_rowparam_kernel(
|
| 295 |
+
BETA_ROW_ptr, U_ptr, VTH_ROW_ptr, INIT_ptr,
|
| 296 |
+
SPIKE_ptr, VPOST_ptr,
|
| 297 |
+
K, num_cols,
|
| 298 |
+
BLOCK: tl.constexpr,
|
| 299 |
+
):
|
| 300 |
+
"""Fused PLIF forward with row-parameter beta and v_th.
|
| 301 |
+
|
| 302 |
+
beta and v_th are (*shape) — constant across K steps, loaded once into registers.
|
| 303 |
+
Reduces global memory reads from 3 per step (beta, u, v_th) to 1 (u only).
|
| 304 |
+
"""
|
| 305 |
+
pid = tl.program_id(0)
|
| 306 |
+
cols = pid * BLOCK + tl.arange(0, BLOCK)
|
| 307 |
+
mask = cols < num_cols
|
| 308 |
+
|
| 309 |
+
v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 310 |
+
beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 311 |
+
vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 312 |
+
|
| 313 |
+
for k in range(K):
|
| 314 |
+
off = k * num_cols + cols
|
| 315 |
+
u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 316 |
+
|
| 317 |
+
v_pre = beta * v + u
|
| 318 |
+
spike = tl.where(v_pre >= vth, 1.0, 0.0)
|
| 319 |
+
v = v_pre - vth * spike
|
| 320 |
+
|
| 321 |
+
tl.store(SPIKE_ptr + off, spike, mask=mask)
|
| 322 |
+
tl.store(VPOST_ptr + off, v, mask=mask)
|
| 323 |
+
|
| 324 |
+
@triton.jit
|
| 325 |
+
def _fused_plif_bwd_rowparam_kernel(
|
| 326 |
+
BETA_ROW_ptr, VTH_ROW_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
|
| 327 |
+
GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
|
| 328 |
+
GRAD_BETA_ROW_ptr, GRAD_U_ptr, GRAD_VTH_ROW_ptr, GRAD_INIT_ptr,
|
| 329 |
+
K, num_cols, ALPHA,
|
| 330 |
+
BLOCK: tl.constexpr,
|
| 331 |
+
):
|
| 332 |
+
"""Fused PLIF backward with row-parameter beta/v_th.
|
| 333 |
+
|
| 334 |
+
Gradients for beta and v_th are accumulated over K steps (reduction in registers).
|
| 335 |
+
Returns grad_beta_row (*shape) and grad_v_th_row (*shape) instead of per-step gradients.
|
| 336 |
+
"""
|
| 337 |
+
pid = tl.program_id(0)
|
| 338 |
+
cols = pid * BLOCK + tl.arange(0, BLOCK)
|
| 339 |
+
mask = cols < num_cols
|
| 340 |
+
|
| 341 |
+
beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 342 |
+
vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 343 |
+
|
| 344 |
+
acc = tl.zeros([BLOCK], dtype=tl.float32)
|
| 345 |
+
acc_grad_beta = tl.zeros([BLOCK], dtype=tl.float32)
|
| 346 |
+
acc_grad_vth = tl.zeros([BLOCK], dtype=tl.float32)
|
| 347 |
+
|
| 348 |
+
for k_rev in range(K):
|
| 349 |
+
k = K - 1 - k_rev
|
| 350 |
+
off = k * num_cols + cols
|
| 351 |
+
|
| 352 |
+
v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 353 |
+
spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 354 |
+
|
| 355 |
+
g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 356 |
+
g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
|
| 357 |
+
|
| 358 |
+
if k > 0:
|
| 359 |
+
v_prev = tl.load(
|
| 360 |
+
VPOST_ptr + (k - 1) * num_cols + cols,
|
| 361 |
+
mask=mask, other=0.0,
|
| 362 |
+
).to(tl.float32)
|
| 363 |
+
else:
|
| 364 |
+
v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
| 365 |
+
|
| 366 |
+
# Sigmoid surrogate gradient
|
| 367 |
+
x = v_post - vth * (1.0 - spike)
|
| 368 |
+
neg_ax = -ALPHA * x
|
| 369 |
+
neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)
|
| 370 |
+
sig = 1.0 / (1.0 + tl.exp(neg_ax))
|
| 371 |
+
sg = ALPHA * sig * (1.0 - sig)
|
| 372 |
+
|
| 373 |
+
total_gV = g_V + acc
|
| 374 |
+
grad_v_pre = g_s * sg + total_gV
|
| 375 |
+
|
| 376 |
+
tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
|
| 377 |
+
|
| 378 |
+
# Accumulate gradients for row parameters (reduction over K in registers)
|
| 379 |
+
acc_grad_beta += grad_v_pre * v_prev
|
| 380 |
+
acc_grad_vth += -g_s * sg - total_gV * spike
|
| 381 |
+
|
| 382 |
+
acc = grad_v_pre * beta
|
| 383 |
+
|
| 384 |
+
tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
|
| 385 |
+
tl.store(GRAD_BETA_ROW_ptr + cols, acc_grad_beta, mask=mask)
|
| 386 |
+
tl.store(GRAD_VTH_ROW_ptr + cols, acc_grad_vth, mask=mask)
|
| 387 |
+
|
| 388 |
+
class _TritonPLIFRowParamForward(torch.autograd.Function):
|
| 389 |
+
"""Fused Triton PLIF with row-parameter beta/v_th.
|
| 390 |
+
|
| 391 |
+
For neurons with constant beta/v_th across K steps (ParametricLIFNode).
|
| 392 |
+
Eliminates expand+contiguous for beta/v_th tensors, reduces memory I/O by ~40%.
|
| 393 |
+
"""
|
| 394 |
+
|
| 395 |
+
_BLOCK = 128
|
| 396 |
+
|
| 397 |
+
@staticmethod
|
| 398 |
+
def forward(ctx, beta_row, u, v_th_row, v_init, alpha):
|
| 399 |
+
beta_row_c = beta_row.contiguous()
|
| 400 |
+
u_c = u.contiguous()
|
| 401 |
+
v_th_row_c = v_th_row.contiguous()
|
| 402 |
+
v_init_c = v_init.contiguous()
|
| 403 |
+
|
| 404 |
+
K = u_c.shape[0]
|
| 405 |
+
num_cols = u_c[0].numel()
|
| 406 |
+
|
| 407 |
+
spike = torch.empty_like(u_c)
|
| 408 |
+
V_post = torch.empty_like(u_c)
|
| 409 |
+
|
| 410 |
+
BLOCK = _TritonPLIFRowParamForward._BLOCK
|
| 411 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 412 |
+
|
| 413 |
+
_fused_plif_fwd_rowparam_kernel[grid](
|
| 414 |
+
beta_row_c, u_c, v_th_row_c, v_init_c,
|
| 415 |
+
spike, V_post,
|
| 416 |
+
K, num_cols,
|
| 417 |
+
BLOCK=BLOCK,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
if any(ctx.needs_input_grad[:4]):
|
| 421 |
+
ctx.save_for_backward(beta_row_c, v_th_row_c, v_init_c, V_post, spike)
|
| 422 |
+
ctx.K = K
|
| 423 |
+
ctx.num_cols = num_cols
|
| 424 |
+
ctx.alpha = alpha
|
| 425 |
+
|
| 426 |
+
return spike, V_post
|
| 427 |
+
|
| 428 |
+
@staticmethod
|
| 429 |
+
def backward(ctx, grad_spike, grad_V_post):
|
| 430 |
+
beta_row, v_th_row, v_init, V_post, spike = ctx.saved_tensors
|
| 431 |
+
K = ctx.K
|
| 432 |
+
num_cols = ctx.num_cols
|
| 433 |
+
alpha = ctx.alpha
|
| 434 |
+
|
| 435 |
+
if grad_spike is None:
|
| 436 |
+
grad_spike = torch.zeros_like(spike)
|
| 437 |
+
if grad_V_post is None:
|
| 438 |
+
grad_V_post = torch.zeros_like(V_post)
|
| 439 |
+
|
| 440 |
+
grad_spike_c = grad_spike.contiguous()
|
| 441 |
+
grad_V_post_c = grad_V_post.contiguous()
|
| 442 |
+
|
| 443 |
+
grad_beta_row = torch.empty_like(beta_row)
|
| 444 |
+
grad_u = torch.empty_like(V_post)
|
| 445 |
+
grad_v_th_row = torch.empty_like(v_th_row)
|
| 446 |
+
grad_v_init = torch.empty_like(v_init)
|
| 447 |
+
|
| 448 |
+
BLOCK = _TritonPLIFRowParamForward._BLOCK
|
| 449 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 450 |
+
|
| 451 |
+
_fused_plif_bwd_rowparam_kernel[grid](
|
| 452 |
+
beta_row, v_th_row, v_init, V_post, spike,
|
| 453 |
+
grad_spike_c, grad_V_post_c,
|
| 454 |
+
grad_beta_row, grad_u, grad_v_th_row, grad_v_init,
|
| 455 |
+
K, num_cols, float(alpha),
|
| 456 |
+
BLOCK=BLOCK,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
return grad_beta_row, grad_u, grad_v_th_row, grad_v_init, None
|
| 460 |
+
|
| 461 |
+
class _TritonPLIFForward(torch.autograd.Function):
|
| 462 |
+
"""Fused Triton PLIF forward + backward.
|
| 463 |
+
|
| 464 |
+
Single-pass sequential scan replaces the 3-phase approach:
|
| 465 |
+
Phase 1 (linear scan) + Phase 2 (spike iteration) + Phase 3 (correction)
|
| 466 |
+
→ 1 fused kernel with inline spike detection + soft reset
|
| 467 |
+
|
| 468 |
+
Advantages:
|
| 469 |
+
- 1 kernel launch (vs 3-4 launches + ~10 element-wise ops)
|
| 470 |
+
- Exact computation (no iteration convergence issues)
|
| 471 |
+
- Less memory (no intermediate V_L, delta_S, delta_S_prev)
|
| 472 |
+
- Higher precision (fp32 accumulation, no bf16 intermediate store/load)
|
| 473 |
+
"""
|
| 474 |
+
|
| 475 |
+
_BLOCK = 128
|
| 476 |
+
|
| 477 |
+
@staticmethod
|
| 478 |
+
def forward(ctx, beta, u, v_th, v_init, alpha):
|
| 479 |
+
beta_c = beta.contiguous()
|
| 480 |
+
u_c = u.contiguous()
|
| 481 |
+
v_th_c = v_th.contiguous()
|
| 482 |
+
v_init_c = v_init.contiguous()
|
| 483 |
+
|
| 484 |
+
K = beta_c.shape[0]
|
| 485 |
+
num_cols = beta_c[0].numel()
|
| 486 |
+
|
| 487 |
+
spike = torch.empty_like(u_c)
|
| 488 |
+
V_post = torch.empty_like(u_c)
|
| 489 |
+
|
| 490 |
+
BLOCK = _TritonPLIFForward._BLOCK
|
| 491 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 492 |
+
|
| 493 |
+
_fused_plif_fwd_kernel[grid](
|
| 494 |
+
beta_c, u_c, v_th_c, v_init_c,
|
| 495 |
+
spike, V_post,
|
| 496 |
+
K, num_cols,
|
| 497 |
+
BLOCK=BLOCK,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if any(ctx.needs_input_grad[:4]):
|
| 501 |
+
ctx.save_for_backward(beta_c, v_th_c, v_init_c, V_post, spike)
|
| 502 |
+
ctx.K = K
|
| 503 |
+
ctx.num_cols = num_cols
|
| 504 |
+
ctx.alpha = alpha
|
| 505 |
+
|
| 506 |
+
return spike, V_post
|
| 507 |
+
|
| 508 |
+
@staticmethod
|
| 509 |
+
def backward(ctx, grad_spike, grad_V_post):
|
| 510 |
+
beta, v_th, v_init, V_post, spike = ctx.saved_tensors
|
| 511 |
+
K = ctx.K
|
| 512 |
+
num_cols = ctx.num_cols
|
| 513 |
+
alpha = ctx.alpha
|
| 514 |
+
|
| 515 |
+
if grad_spike is None:
|
| 516 |
+
grad_spike = torch.zeros_like(spike)
|
| 517 |
+
if grad_V_post is None:
|
| 518 |
+
grad_V_post = torch.zeros_like(V_post)
|
| 519 |
+
|
| 520 |
+
grad_spike_c = grad_spike.contiguous()
|
| 521 |
+
grad_V_post_c = grad_V_post.contiguous()
|
| 522 |
+
|
| 523 |
+
grad_beta = torch.empty_like(beta)
|
| 524 |
+
grad_u = torch.empty_like(beta)
|
| 525 |
+
grad_v_th = torch.empty_like(v_th)
|
| 526 |
+
grad_v_init = torch.empty_like(v_init)
|
| 527 |
+
|
| 528 |
+
BLOCK = _TritonPLIFForward._BLOCK
|
| 529 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 530 |
+
|
| 531 |
+
_fused_plif_bwd_kernel[grid](
|
| 532 |
+
beta, v_th, v_init, V_post, spike,
|
| 533 |
+
grad_spike_c, grad_V_post_c,
|
| 534 |
+
grad_beta, grad_u, grad_v_th, grad_v_init,
|
| 535 |
+
K, num_cols, float(alpha),
|
| 536 |
+
BLOCK=BLOCK,
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
return grad_beta, grad_u, grad_v_th, grad_v_init, None
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# ============================================================
|
| 543 |
+
# Hillis-Steele parallel prefix scan (CPU fallback)
|
| 544 |
+
# ============================================================
|
| 545 |
+
|
| 546 |
+
def hillis_steele_scan(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 547 |
+
"""
|
| 548 |
+
Hillis-Steele 并行前缀扫描:计算仿射映射序列的所有前缀复合。
|
| 549 |
+
|
| 550 |
+
给定仿射映射 f_k(x) = a[k] * x + b[k], k = 0, ..., K-1,
|
| 551 |
+
计算前缀复合 F_k = f_k ∘ f_{k-1} ∘ ... ∘ f_0,
|
| 552 |
+
使得 V[k] = F_k(v_init) = A[k] * v_init + B[k]。
|
| 553 |
+
|
| 554 |
+
复合规则: (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2)
|
| 555 |
+
|
| 556 |
+
实现使用 torch.cat 重建张量(无原地操作),完全兼容 autograd。
|
| 557 |
+
|
| 558 |
+
Args:
|
| 559 |
+
a: (K, *shape) — 乘性系数(如 β)
|
| 560 |
+
b: (K, *shape) — 加性项(如 α·I)
|
| 561 |
+
|
| 562 |
+
Returns:
|
| 563 |
+
A: (K, *shape) — 前缀积 A[k] = ∏_{j=0}^{k} a[j]
|
| 564 |
+
B: (K, *shape) — 前缀和 B[k] 使得 V[k] = A[k] * v_init + B[k]
|
| 565 |
+
|
| 566 |
+
并行深度: O(log K)
|
| 567 |
+
工作量: O(K * log K)
|
| 568 |
+
"""
|
| 569 |
+
K = a.shape[0]
|
| 570 |
+
A = a
|
| 571 |
+
B = b
|
| 572 |
+
|
| 573 |
+
d = 1
|
| 574 |
+
while d < K:
|
| 575 |
+
A_new_tail = A[d:] * A[:-d]
|
| 576 |
+
B_new_tail = A[d:] * B[:-d] + B[d:]
|
| 577 |
+
|
| 578 |
+
A = torch.cat([A[:d], A_new_tail], dim=0)
|
| 579 |
+
B = torch.cat([B[:d], B_new_tail], dim=0)
|
| 580 |
+
|
| 581 |
+
d *= 2
|
| 582 |
+
|
| 583 |
+
return A, B
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
# ============================================================
|
| 587 |
+
# Public API: linear_recurrence
|
| 588 |
+
# ============================================================
|
| 589 |
+
|
| 590 |
+
def linear_recurrence(beta: torch.Tensor, u: torch.Tensor, v_init: torch.Tensor) -> torch.Tensor:
|
| 591 |
+
"""
|
| 592 |
+
求解线性递推: V[k] = beta[k] * V[k-1] + u[k], V[-1] = v_init
|
| 593 |
+
|
| 594 |
+
CUDA 后端: Triton fused kernel(1 次 kernel launch,O(K) 工作量)
|
| 595 |
+
CPU 后端: Hillis-Steele parallel scan(O(K log K) 工作量)
|
| 596 |
+
|
| 597 |
+
Args:
|
| 598 |
+
beta: (K, *shape) — 衰减系数,值域 (0, 1)
|
| 599 |
+
u: (K, *shape) — 输入项
|
| 600 |
+
v_init: (*shape) — 初始状态
|
| 601 |
+
|
| 602 |
+
Returns:
|
| 603 |
+
V: (K, *shape) — 所有 K 步的状态
|
| 604 |
+
"""
|
| 605 |
+
if _HAS_TRITON and beta.is_cuda:
|
| 606 |
+
return _TritonLinearRecurrence.apply(beta, u, v_init)
|
| 607 |
+
# CPU fallback
|
| 608 |
+
A, B = hillis_steele_scan(beta, u)
|
| 609 |
+
V = A * v_init.unsqueeze(0) + B
|
| 610 |
+
return V
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
# ============================================================
|
| 614 |
+
# PLIF parallel forward (with spike iteration)
|
| 615 |
+
# ============================================================
|
| 616 |
+
|
| 617 |
+
def plif_parallel_forward(
|
| 618 |
+
beta: torch.Tensor,
|
| 619 |
+
u: torch.Tensor,
|
| 620 |
+
v_th: torch.Tensor,
|
| 621 |
+
v_init: torch.Tensor,
|
| 622 |
+
max_iter: int = 3,
|
| 623 |
+
surrogate_function=None,
|
| 624 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 625 |
+
"""
|
| 626 |
+
PLIF 神经元的并行前向传播(soft reset,surrogate gradient 兼容)。
|
| 627 |
+
|
| 628 |
+
求解:
|
| 629 |
+
V_pre[k] = beta[k] * V_post[k-1] + u[k]
|
| 630 |
+
s[k] = Θ(V_pre[k] - v_th[k])
|
| 631 |
+
V_post[k] = V_pre[k] - v_th[k] * s[k]
|
| 632 |
+
|
| 633 |
+
方法:
|
| 634 |
+
Phase 1: 线性轨迹 parallel scan(有梯度)
|
| 635 |
+
Phase 2: spike 不动点迭代(detach,确定离散 spike pattern)
|
| 636 |
+
Phase 3: 用 converged spike pattern 重算 V_post(有梯度),
|
| 637 |
+
surrogate_function(V_pre - v_th) 生成可微 spike 输出
|
| 638 |
+
|
| 639 |
+
Args:
|
| 640 |
+
beta: (K, *shape) — 衰减系数
|
| 641 |
+
u: (K, *shape) — 输入 α·I
|
| 642 |
+
v_th: (K, *shape) — 动态阈值
|
| 643 |
+
v_init: (*shape) — 初始膜电位
|
| 644 |
+
max_iter: spike 不动点迭代次数上限
|
| 645 |
+
surrogate_function: SpikingJelly surrogate gradient 函数(如 surrogate.Sigmoid(alpha=4.0))
|
| 646 |
+
None 时退化为硬阈值(无梯度)
|
| 647 |
+
|
| 648 |
+
Returns:
|
| 649 |
+
spike: (K, *shape) — spike 模式(有 surrogate gradient)
|
| 650 |
+
V_post: (K, *shape) — 发放后膜电位
|
| 651 |
+
V_pre: (K, *shape) — 发放前膜电位(fused path 返回 None)
|
| 652 |
+
"""
|
| 653 |
+
# Fused Triton path: single-pass sequential scan (exact, no iteration)
|
| 654 |
+
# Replaces 3-phase approach with 1 kernel launch — ~3x faster forward, ~5x faster backward
|
| 655 |
+
if (_HAS_TRITON and beta.is_cuda and surrogate_function is not None
|
| 656 |
+
and hasattr(surrogate_function, 'alpha')
|
| 657 |
+
and type(surrogate_function).__name__ == 'Sigmoid'):
|
| 658 |
+
alpha = float(surrogate_function.alpha)
|
| 659 |
+
spike, V_post = _TritonPLIFForward.apply(beta, u, v_th, v_init, alpha)
|
| 660 |
+
return spike, V_post, None
|
| 661 |
+
|
| 662 |
+
# Fallback: 3-phase approach (CPU, non-Sigmoid surrogates, or no surrogate)
|
| 663 |
+
# Phase 1: 线性轨迹 V_L (假设从不发放)
|
| 664 |
+
V_L = linear_recurrence(beta, u, v_init) # (K, *shape)
|
| 665 |
+
|
| 666 |
+
# Phase 2: Spike 不动点迭代(全部 detach,不建立梯度图)
|
| 667 |
+
# 目的:确定哪些神经元在哪些步发放(离散决策)
|
| 668 |
+
with torch.no_grad():
|
| 669 |
+
V_L_det = V_L.detach()
|
| 670 |
+
beta_det = beta.detach()
|
| 671 |
+
v_th_det = v_th.detach()
|
| 672 |
+
v_init_det = v_init.detach() if isinstance(v_init, torch.Tensor) else v_init
|
| 673 |
+
|
| 674 |
+
spike_pattern = (V_L_det >= v_th_det).float()
|
| 675 |
+
|
| 676 |
+
for _ in range(max_iter - 1):
|
| 677 |
+
# 计算 ΔS: ΔS[k] = beta[k] * ΔS[k-1] + v_th[k] * s[k]
|
| 678 |
+
delta_S = linear_recurrence(
|
| 679 |
+
beta_det, v_th_det * spike_pattern,
|
| 680 |
+
torch.zeros_like(v_init_det) if isinstance(v_init_det, torch.Tensor)
|
| 681 |
+
else torch.zeros_like(V_L_det[0]),
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
# ΔS_prev = ΔS[k-1](位移一步)
|
| 685 |
+
delta_S_prev = torch.zeros_like(delta_S)
|
| 686 |
+
delta_S_prev[1:] = delta_S[:-1]
|
| 687 |
+
|
| 688 |
+
# V_pre = V_L - beta * ΔS_prev
|
| 689 |
+
V_pre_det = V_L_det - beta_det * delta_S_prev
|
| 690 |
+
|
| 691 |
+
# 更新 spike
|
| 692 |
+
spike_new = (V_pre_det >= v_th_det).float()
|
| 693 |
+
|
| 694 |
+
# 收敛检查
|
| 695 |
+
if torch.equal(spike_new, spike_pattern):
|
| 696 |
+
break
|
| 697 |
+
spike_pattern = spike_new
|
| 698 |
+
|
| 699 |
+
# Phase 3: 用 converged spike pattern 重算 V_post(有完整梯度)
|
| 700 |
+
# spike_pattern 是 detached 的,作为常数参与计算
|
| 701 |
+
# 梯度通过 u, v_th, beta, v_init 流动
|
| 702 |
+
u_eff = u - v_th * spike_pattern
|
| 703 |
+
V_post = linear_recurrence(beta, u_eff, v_init) # (K, *shape)
|
| 704 |
+
|
| 705 |
+
# 重建 V_pre(有梯度,用于 surrogate gradient)
|
| 706 |
+
V_post_prev = torch.zeros_like(V_post)
|
| 707 |
+
if isinstance(v_init, torch.Tensor):
|
| 708 |
+
V_post_prev[0] = v_init
|
| 709 |
+
V_post_prev[1:] = V_post[:-1]
|
| 710 |
+
V_pre = beta * V_post_prev + u
|
| 711 |
+
|
| 712 |
+
# 生成可微 spike 输出
|
| 713 |
+
if surrogate_function is not None:
|
| 714 |
+
# forward: Heaviside(V_pre - v_th), backward: surrogate gradient
|
| 715 |
+
spike = surrogate_function(V_pre - v_th)
|
| 716 |
+
else:
|
| 717 |
+
# 退化模式:硬阈值,无梯度
|
| 718 |
+
spike = (V_pre >= v_th).float()
|
| 719 |
+
|
| 720 |
+
return spike, V_post, V_pre
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
def plif_rowparam_forward(
|
| 724 |
+
beta_row: torch.Tensor,
|
| 725 |
+
u: torch.Tensor,
|
| 726 |
+
v_th_row: torch.Tensor,
|
| 727 |
+
v_init: torch.Tensor,
|
| 728 |
+
surrogate_function=None,
|
| 729 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 730 |
+
"""
|
| 731 |
+
行参数 PLIF 前向:beta 和 v_th 在 K 步中保持恒定。
|
| 732 |
+
|
| 733 |
+
比 plif_parallel_forward 快 ~40%(省去 expand+contiguous,减少 2/3 显存读取)。
|
| 734 |
+
用于 ParametricLIFNode(固定 beta/v_th)或合并多个固定参数神经元。
|
| 735 |
+
|
| 736 |
+
Args:
|
| 737 |
+
beta_row: (*shape) — 每列的衰减率(所有 K 步相同)
|
| 738 |
+
u: (K, *shape) — 每步输入
|
| 739 |
+
v_th_row: (*shape) — 每列的阈值(所有 K 步相同)
|
| 740 |
+
v_init: (*shape) — 初始膜电位
|
| 741 |
+
surrogate_function: surrogate gradient 函数
|
| 742 |
+
|
| 743 |
+
Returns:
|
| 744 |
+
spike: (K, *shape) — spike 模式
|
| 745 |
+
V_post: (K, *shape) — 发放后膜电位
|
| 746 |
+
"""
|
| 747 |
+
if (_HAS_TRITON and u.is_cuda and surrogate_function is not None
|
| 748 |
+
and hasattr(surrogate_function, 'alpha')
|
| 749 |
+
and type(surrogate_function).__name__ == 'Sigmoid'):
|
| 750 |
+
alpha = float(surrogate_function.alpha)
|
| 751 |
+
spike, V_post = _TritonPLIFRowParamForward.apply(
|
| 752 |
+
beta_row, u, v_th_row, v_init, alpha,
|
| 753 |
+
)
|
| 754 |
+
return spike, V_post
|
| 755 |
+
|
| 756 |
+
# Fallback: expand to full (K, *shape) and use standard path
|
| 757 |
+
K = u.shape[0]
|
| 758 |
+
beta = beta_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
|
| 759 |
+
v_th = v_th_row.unsqueeze(0).expand(K, *u.shape[1:]).contiguous()
|
| 760 |
+
spike, V_post, _ = plif_parallel_forward(beta, u, v_th, v_init, surrogate_function=surrogate_function)
|
| 761 |
+
return spike, V_post
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
def plif_fixed_param_forward(
|
| 765 |
+
beta,
|
| 766 |
+
u: torch.Tensor,
|
| 767 |
+
v_th,
|
| 768 |
+
v_init: torch.Tensor,
|
| 769 |
+
max_iter: int = 3,
|
| 770 |
+
surrogate_function=None,
|
| 771 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 772 |
+
"""
|
| 773 |
+
固定参数 PLIF 神经元的并行前向(如输出神经元、FFN 神经元)。
|
| 774 |
+
|
| 775 |
+
ParametricLIFNode 方程: V[k] = beta * V[k-1] + (1-beta) * x[k]
|
| 776 |
+
其中 beta = 1/(1+exp(w)), 可为 scalar tensor(保持梯度流向 w)。
|
| 777 |
+
|
| 778 |
+
scalar/0-dim beta 和 v_th 使用 row-param 内核(无需 expand 到 (K, *shape))。
|
| 779 |
+
|
| 780 |
+
Args:
|
| 781 |
+
beta: 衰减率 — scalar float、0-dim tensor 或 (K, *shape) tensor
|
| 782 |
+
u: (K, *shape) — 输入(已乘以 (1-beta))
|
| 783 |
+
v_th: 阈值 — scalar float、0-dim tensor 或 (K, *shape) tensor
|
| 784 |
+
v_init: (*shape) — 初始膜电位
|
| 785 |
+
max_iter: spike 迭代次数
|
| 786 |
+
surrogate_function: surrogate gradient 函数
|
| 787 |
+
|
| 788 |
+
Returns:
|
| 789 |
+
spike: (K, *shape) — spike 模式
|
| 790 |
+
V_post: (K, *shape) — 发放后膜电位
|
| 791 |
+
"""
|
| 792 |
+
K = u.shape[0]
|
| 793 |
+
shape = u.shape[1:]
|
| 794 |
+
|
| 795 |
+
# Row-param fast path: beta 和 v_th 都是 scalar/0-dim → 扩展为 (*shape) 行向量
|
| 796 |
+
beta_is_scalar = isinstance(beta, torch.Tensor) and beta.dim() == 0
|
| 797 |
+
beta_is_float = not isinstance(beta, torch.Tensor)
|
| 798 |
+
vth_is_scalar = isinstance(v_th, torch.Tensor) and v_th.dim() == 0
|
| 799 |
+
vth_is_float = not isinstance(v_th, torch.Tensor)
|
| 800 |
+
|
| 801 |
+
if (beta_is_scalar or beta_is_float) and (vth_is_scalar or vth_is_float):
|
| 802 |
+
# Build row vectors (*shape)
|
| 803 |
+
if beta_is_scalar:
|
| 804 |
+
beta_row = beta.expand(*shape).contiguous()
|
| 805 |
+
else:
|
| 806 |
+
beta_row = torch.full(shape, beta, device=u.device, dtype=u.dtype)
|
| 807 |
+
if vth_is_scalar:
|
| 808 |
+
v_th_row = v_th.expand(*shape).contiguous()
|
| 809 |
+
else:
|
| 810 |
+
v_th_row = torch.full(shape, v_th, device=u.device, dtype=u.dtype)
|
| 811 |
+
return plif_rowparam_forward(beta_row, u, v_th_row, v_init, surrogate_function)
|
| 812 |
+
|
| 813 |
+
# Full-tensor path: expand to (K, *shape) if needed
|
| 814 |
+
if isinstance(beta, torch.Tensor):
|
| 815 |
+
if beta.dim() == 0:
|
| 816 |
+
beta = beta.expand(K, *shape).contiguous()
|
| 817 |
+
else:
|
| 818 |
+
beta = torch.full_like(u, beta)
|
| 819 |
+
|
| 820 |
+
if isinstance(v_th, torch.Tensor):
|
| 821 |
+
if v_th.dim() == 0:
|
| 822 |
+
v_th = v_th.expand(K, *shape).contiguous()
|
| 823 |
+
else:
|
| 824 |
+
v_th = torch.full_like(u, v_th)
|
| 825 |
+
|
| 826 |
+
spike, V_post, _ = plif_parallel_forward(
|
| 827 |
+
beta, u, v_th, v_init, max_iter, surrogate_function,
|
| 828 |
+
)
|
| 829 |
+
return spike, V_post
|
atomic_ops/plif_node.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PLIFNode: D 维固定参数 PLIF 神经元(设计文档 5.5 "普通 SNN 神经元")
|
| 3 |
+
|
| 4 |
+
与 SelectivePLIFNode 的区别:
|
| 5 |
+
SelectivePLIF: β(t), α(t), V_th(t) 由输入每步动态计算(选择性记忆)
|
| 6 |
+
PLIFNode: β, V_th 为 D 维可学习参数,训练后固定(信号转换)
|
| 7 |
+
|
| 8 |
+
每个维度有独立的可学习参数:
|
| 9 |
+
β_d = sigmoid(w_d): 时间常数(衰减率)
|
| 10 |
+
V_th_d: 发放阈值
|
| 11 |
+
|
| 12 |
+
动力学(与 ParametricLIF 一致):
|
| 13 |
+
V[t] = β · V[t-1] + (1-β) · x[t]
|
| 14 |
+
s[t] = Θ(V[t] - V_th) (surrogate gradient)
|
| 15 |
+
V[t] -= V_th · s[t] (soft reset)
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from spikingjelly.activation_based import base, surrogate
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PLIFNode(base.MemoryModule):
|
| 26 |
+
"""
|
| 27 |
+
D 维固定参数 PLIF 神经元。
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
dim: 神经元数量(每个维度独立参数)
|
| 31 |
+
init_tau: 初始时间常数 τ(β = 1 - 1/τ)
|
| 32 |
+
v_threshold: 初始发放阈值
|
| 33 |
+
surrogate_function: surrogate gradient 函数
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
dim: int,
|
| 39 |
+
init_tau: float = 2.0,
|
| 40 |
+
v_threshold: float = 0.5,
|
| 41 |
+
surrogate_function=surrogate.Sigmoid(alpha=4.0),
|
| 42 |
+
):
|
| 43 |
+
super().__init__()
|
| 44 |
+
# D 维可学习参数(随机初始化,每个维度独立)
|
| 45 |
+
# w: 控制 β=sigmoid(w),随机产生不同时间常数
|
| 46 |
+
# init_w ± 0.5 → β ∈ ~[sigmoid(w-0.5), sigmoid(w+0.5)]
|
| 47 |
+
# tau=2.0 时 w=0, β ∈ ~[0.38, 0.62]
|
| 48 |
+
init_w = -math.log(init_tau - 1.0)
|
| 49 |
+
self.w = nn.Parameter(torch.empty(dim).normal_(init_w, 0.5))
|
| 50 |
+
# v_th: 发放阈值,U[0.5x, 1.5x] 均匀分布产生维度间多样性
|
| 51 |
+
self.v_th = nn.Parameter(torch.empty(dim).uniform_(
|
| 52 |
+
v_threshold * 0.5, v_threshold * 1.5,
|
| 53 |
+
))
|
| 54 |
+
self.surrogate_function = surrogate_function
|
| 55 |
+
# 膜电位状态(functional.reset_net 时重置为 0.)
|
| 56 |
+
self.register_memory('v', 0.)
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def beta(self):
|
| 60 |
+
"""D 维衰减率 β = sigmoid(w),值域 (0, 1)。"""
|
| 61 |
+
return torch.sigmoid(self.w)
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
"""
|
| 65 |
+
单步前向传播。
|
| 66 |
+
|
| 67 |
+
V[t] = β · V[t-1] + (1-β) · x[t], spike = Θ(V-V_th), soft reset。
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
x: 输入电流, shape (batch, dim)
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
spike: 二值脉冲, shape (batch, dim), 值域 {0, 1}
|
| 74 |
+
"""
|
| 75 |
+
if isinstance(self.v, float):
|
| 76 |
+
self.v = torch.zeros_like(x)
|
| 77 |
+
beta = self.beta
|
| 78 |
+
self.v = beta * self.v + (1.0 - beta) * x
|
| 79 |
+
spike = self.surrogate_function(self.v - self.v_th)
|
| 80 |
+
self.v = self.v - spike * self.v_th # soft reset
|
| 81 |
+
return spike
|
atomic_ops/rms_norm.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RMSNorm: 残差流分支归一化(Pre-LN 模式)。
|
| 3 |
+
|
| 4 |
+
位置: h → RMSNorm → PLIFNode → SNN子层 → out_proj → 残差
|
| 5 |
+
作用: 控制送入 PLIFNode 的输入 scale,防止残差流漂移/爆炸。
|
| 6 |
+
仅归一化分支输入,残差流本身不被归一化。
|
| 7 |
+
|
| 8 |
+
对标 Qwen3/LLaMA 的 Pre-LN RMSNorm。
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RMSNorm(nn.Module):
|
| 16 |
+
"""Root Mean Square Layer Normalization.
|
| 17 |
+
|
| 18 |
+
x_norm = x / RMS(x) * weight
|
| 19 |
+
RMS(x) = sqrt(mean(x^2) + eps)
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
dim: 归一化维度
|
| 23 |
+
eps: 数值稳定性
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 29 |
+
self.eps = eps
|
| 30 |
+
|
| 31 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
input_dtype = x.dtype
|
| 33 |
+
x = x.float()
|
| 34 |
+
variance = x.pow(2).mean(-1, keepdim=True)
|
| 35 |
+
x = x * torch.rsqrt(variance + self.eps)
|
| 36 |
+
return (self.weight * x).to(input_dtype)
|
atomic_ops/selective_plif.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SelectivePLIFNode: 动态参数的 PLIF 神经元
|
| 3 |
+
|
| 4 |
+
与标准 ParametricLIFNode 的区别:
|
| 5 |
+
- β(t), α(t), V_th(t) 作为外部参数每步传入,不是内部 nn.Parameter
|
| 6 |
+
- 本神经元无可训练参数,所有学习发生在 SNNBlock 的调制网络中
|
| 7 |
+
- 仅支持 step_mode='s'(单步模式)
|
| 8 |
+
- 仅支持 soft reset(v_reset=None)
|
| 9 |
+
|
| 10 |
+
状态方程:
|
| 11 |
+
V[t] = β(t) · V[t-1] + α(t) · I[t]
|
| 12 |
+
s[t] = Θ(V[t] - V_th(t)) (surrogate gradient)
|
| 13 |
+
V[t] -= V_th(t) · s[t] (soft reset)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from spikingjelly.activation_based import neuron, surrogate
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SelectivePLIFNode(neuron.BaseNode):
|
| 21 |
+
"""
|
| 22 |
+
隐状态空间的核心神经元。
|
| 23 |
+
|
| 24 |
+
接收外部动态计算的 β(t), α(t), V_th(t),执行:
|
| 25 |
+
charge → fire → soft reset
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
surrogate_function: surrogate gradient 函数,默认 Sigmoid(alpha=4.0)
|
| 29 |
+
detach_reset: 是否在 reset 时 detach spike,默认 False
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
surrogate_function=surrogate.Sigmoid(alpha=4.0),
|
| 35 |
+
detach_reset: bool = False,
|
| 36 |
+
):
|
| 37 |
+
# v_threshold=1.0 是占位值,实际使用外部传入的 v_th
|
| 38 |
+
# v_reset=None 触发 soft reset 模式,register_memory('v', 0.)
|
| 39 |
+
super().__init__(
|
| 40 |
+
v_threshold=1.0,
|
| 41 |
+
v_reset=None,
|
| 42 |
+
surrogate_function=surrogate_function,
|
| 43 |
+
detach_reset=detach_reset,
|
| 44 |
+
step_mode='s',
|
| 45 |
+
backend='torch',
|
| 46 |
+
store_v_seq=False,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def single_step_forward(
|
| 50 |
+
self,
|
| 51 |
+
x: torch.Tensor,
|
| 52 |
+
beta: torch.Tensor,
|
| 53 |
+
alpha: torch.Tensor,
|
| 54 |
+
v_th: torch.Tensor,
|
| 55 |
+
) -> torch.Tensor:
|
| 56 |
+
"""
|
| 57 |
+
单步前向传播。
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
x: 输入电流 I[t], shape (batch, D*N)
|
| 61 |
+
beta: 衰减率 β(t), shape (batch, D*N), 值域 (0, 1)
|
| 62 |
+
alpha: 写入增益 α(t), shape (batch, D*N), 值域 R+
|
| 63 |
+
v_th: 动态阈值 V_th(t), shape (batch, D*N), 值域 R+
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
spike: 二值脉冲 s[t], shape (batch, D*N), 值域 {0, 1}
|
| 67 |
+
"""
|
| 68 |
+
# Phase 0: 首步将 v 从 float 扩展为与输入同形的张量
|
| 69 |
+
self.v_float_to_tensor(x)
|
| 70 |
+
|
| 71 |
+
# Phase 1: Charge — 膜电位更新
|
| 72 |
+
# V[t] = β(t) · V[t-1] + α(t) · I[t]
|
| 73 |
+
self.v = beta * self.v + alpha * x
|
| 74 |
+
|
| 75 |
+
# Phase 2: Fire — 使用动态 v_th(不是 self.v_threshold)
|
| 76 |
+
# spike = Heaviside(V[t] - V_th(t)),反向用 surrogate gradient
|
| 77 |
+
spike = self.surrogate_function(self.v - v_th)
|
| 78 |
+
|
| 79 |
+
# Phase 3: Soft Reset — V[t] -= V_th(t) · s[t]
|
| 80 |
+
if self.detach_reset:
|
| 81 |
+
spike_d = spike.detach()
|
| 82 |
+
else:
|
| 83 |
+
spike_d = spike
|
| 84 |
+
self.v = self.v - spike_d * v_th
|
| 85 |
+
|
| 86 |
+
return spike
|
| 87 |
+
|
| 88 |
+
def extra_repr(self) -> str:
|
| 89 |
+
return (
|
| 90 |
+
f'v_reset={self.v_reset}, '
|
| 91 |
+
f'detach_reset={self.detach_reset}, '
|
| 92 |
+
f'step_mode={self.step_mode}, '
|
| 93 |
+
f'surrogate={self.surrogate_function}'
|
| 94 |
+
)
|
atomic_ops/snn_block.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SNNBlock: 完整的 SNN 隐状态空间 Block(并行化版本)
|
| 3 |
+
|
| 4 |
+
结构(每个 SNN 时间步):
|
| 5 |
+
spike_in {0,1}^D
|
| 6 |
+
├─→ W_in → I[t] ∈ R^{D*N}
|
| 7 |
+
├─→ W_β^(x) + b_β → σ → β(t)
|
| 8 |
+
├─→ W_α^(x) + b_α → softplus → α(t)
|
| 9 |
+
├─→ W_th^(x) + b_th → |·|+V_min → V_th(t)
|
| 10 |
+
├─→ W_gate → sigmoid → gate ∈ (0,1)^D
|
| 11 |
+
└─→ W_skip → I_skip ∈ R^D
|
| 12 |
+
|
| 13 |
+
SelectivePLIF(I, β, α, V_th) → s[t] ∈ {0,1}^{D*N}
|
| 14 |
+
|
| 15 |
+
W_out · V_post[t] ⊙ gate + I_skip → 连续输出 ∈ R^D
|
| 16 |
+
|
| 17 |
+
数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from spikingjelly.activation_based import base, layer, surrogate
|
| 25 |
+
|
| 26 |
+
from .selective_plif import SelectivePLIFNode
|
| 27 |
+
from .parallel_scan import plif_parallel_forward
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ====== Fused modulation activations (torch.compile) ======
|
| 31 |
+
# Fuse sigmoid + softplus + abs + alpha*I into single kernel.
|
| 32 |
+
# 7-8 separate element-wise kernels → 1 fused kernel, ~4x speedup on DN-sized tensors.
|
| 33 |
+
# First call triggers JIT compilation (~seconds); cached for subsequent calls.
|
| 34 |
+
|
| 35 |
+
@torch.compile(backend='inductor', fullgraph=True)
|
| 36 |
+
def _fused_modulation(raw_beta, b_beta, raw_alpha, b_alpha, raw_th, b_th, v_th_min, I_all):
|
| 37 |
+
beta = torch.sigmoid(raw_beta + b_beta)
|
| 38 |
+
alpha = F.softplus(raw_alpha + b_alpha)
|
| 39 |
+
v_th = v_th_min + torch.abs(raw_th + b_th)
|
| 40 |
+
u = alpha * I_all
|
| 41 |
+
return beta, u, v_th
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SNNBlock(base.MemoryModule):
|
| 45 |
+
"""
|
| 46 |
+
单个 SNN Block(并行化)。
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
D: 可见维度(Block 间通信的维度)
|
| 50 |
+
N: 状态扩展因子(每个通道的隐神经元数)
|
| 51 |
+
v_th_min: 动态阈值下限
|
| 52 |
+
surrogate_function: surrogate gradient 函数
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
D: int,
|
| 58 |
+
N: int = 8,
|
| 59 |
+
v_th_min: float = 0.1,
|
| 60 |
+
surrogate_function=surrogate.Sigmoid(alpha=4.0),
|
| 61 |
+
):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.D = D
|
| 64 |
+
self.N = N
|
| 65 |
+
self.v_th_min = v_th_min
|
| 66 |
+
DN = D * N
|
| 67 |
+
|
| 68 |
+
# ====== 六条并行输入投影(SNN 突触:spike 输入) ======
|
| 69 |
+
self.W_in = layer.Linear(D, DN, bias=False, step_mode='s')
|
| 70 |
+
self.W_beta_x = layer.Linear(D, DN, bias=False, step_mode='s')
|
| 71 |
+
self.W_alpha_x = layer.Linear(D, DN, bias=False, step_mode='s')
|
| 72 |
+
self.W_th_x = layer.Linear(D, DN, bias=False, step_mode='s')
|
| 73 |
+
self.W_gate = layer.Linear(D, D, bias=False, step_mode='s')
|
| 74 |
+
self.W_skip = layer.Linear(D, D, bias=False, step_mode='s')
|
| 75 |
+
|
| 76 |
+
# ====== β/α/V_th 仅依赖 spike_in(无 W^(V)·V 项) ======
|
| 77 |
+
|
| 78 |
+
# ====== 调制偏置(结构化初始化) ======
|
| 79 |
+
self.b_beta = nn.Parameter(torch.empty(DN))
|
| 80 |
+
self.b_alpha = nn.Parameter(torch.empty(DN))
|
| 81 |
+
self.b_th = nn.Parameter(torch.empty(DN))
|
| 82 |
+
|
| 83 |
+
# ====== 输出投影:D*N → D(SNN 突触) ======
|
| 84 |
+
self.W_out = layer.Linear(DN, D, bias=False, step_mode='s')
|
| 85 |
+
|
| 86 |
+
# ====== 隐状态空间神经元(D*N 个,动态参数) ======
|
| 87 |
+
self.hidden_neuron = SelectivePLIFNode(
|
| 88 |
+
surrogate_function=surrogate_function,
|
| 89 |
+
detach_reset=False,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# ====== 参数初始化 ======
|
| 93 |
+
self._initialize_parameters()
|
| 94 |
+
|
| 95 |
+
def _initialize_parameters(self):
|
| 96 |
+
"""功能引导初始化。"""
|
| 97 |
+
D, N = self.D, self.N
|
| 98 |
+
K_ref = 16
|
| 99 |
+
|
| 100 |
+
# 目标 β 分布:多时间尺度 [0.80, 0.99]
|
| 101 |
+
beta_values = torch.linspace(0.80, 0.99, N)
|
| 102 |
+
|
| 103 |
+
# ====== 1. β 偏置:logit-spaced + 维度间随机扰动 ======
|
| 104 |
+
b_beta_per_n = torch.log(beta_values / (1.0 - beta_values))
|
| 105 |
+
# 以 per_n 值为均值,加 N(0, 0.1) 扰动打破 D 个通道的对称性
|
| 106 |
+
self.b_beta.data.copy_(b_beta_per_n.repeat(D))
|
| 107 |
+
self.b_beta.data.add_(torch.empty_like(self.b_beta).normal_(0, 0.1))
|
| 108 |
+
|
| 109 |
+
# ====== 2. α 偏置:softplus(0.5413) ≈ 1.0 + 维度间随机扰动 ======
|
| 110 |
+
# 以 0.5413 为均值,N(0, 0.1) 扰动 → α ∈ ~[0.7, 1.3]
|
| 111 |
+
self.b_alpha.data.normal_(0.5413, 0.1)
|
| 112 |
+
|
| 113 |
+
# ====== 3. W^(x) 权重 ======
|
| 114 |
+
for lin in [self.W_in, self.W_gate, self.W_skip, self.W_out]:
|
| 115 |
+
nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
|
| 116 |
+
for lin in [self.W_beta_x, self.W_alpha_x, self.W_th_x]:
|
| 117 |
+
nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
|
| 118 |
+
lin.weight.data.mul_(0.1)
|
| 119 |
+
|
| 120 |
+
# ====== 4. W_in 时间尺度缩放 ======
|
| 121 |
+
scale_per_n = torch.sqrt(1.0 - beta_values ** 2) # (N,)
|
| 122 |
+
scale_DN = scale_per_n.repeat(D) # (D*N,)
|
| 123 |
+
with torch.no_grad():
|
| 124 |
+
self.W_in.weight.mul_(scale_DN.unsqueeze(1))
|
| 125 |
+
|
| 126 |
+
# ====== 5. b_th:σ_V 校准 ======
|
| 127 |
+
# σ_V = sqrt(p/3) * sqrt(1 - β^{2K})
|
| 128 |
+
# 其中 p 是输入 firing rate。旧版假设 p=0.5(σ_I=0.408),
|
| 129 |
+
# 但实际 input_neuron firing rate 约 0.07~0.45,深层更低。
|
| 130 |
+
# 用 p=0.15 保守估计,避免 v_th 过高导致死神经元。
|
| 131 |
+
p_assumed = 0.15
|
| 132 |
+
sigma_I_base = math.sqrt(p_assumed / 3.0)
|
| 133 |
+
sigma_V_per_n = sigma_I_base * torch.sqrt(
|
| 134 |
+
1.0 - beta_values ** (2 * K_ref)
|
| 135 |
+
)
|
| 136 |
+
target_p_fire = torch.linspace(0.25, 0.08, N)
|
| 137 |
+
z_scores = math.sqrt(2.0) * torch.erfinv(
|
| 138 |
+
2.0 * (1.0 - target_p_fire) - 1.0
|
| 139 |
+
)
|
| 140 |
+
target_V_th = sigma_V_per_n * z_scores
|
| 141 |
+
b_th_per_n = torch.clamp(target_V_th - self.v_th_min, min=0.05)
|
| 142 |
+
# 以 per_n 值为均值,加 N(0, 0.02) 扰动打破 D 个通道的对称性
|
| 143 |
+
self.b_th.data.copy_(b_th_per_n.repeat(D))
|
| 144 |
+
self.b_th.data.add_(torch.empty_like(self.b_th).normal_(0, 0.02))
|
| 145 |
+
|
| 146 |
+
# ====== 6. W_out 发放率均衡缩放 ======
|
| 147 |
+
out_scale_per_n = 1.0 / torch.sqrt(target_p_fire)
|
| 148 |
+
out_scale_per_n = out_scale_per_n / out_scale_per_n.mean()
|
| 149 |
+
out_scale_DN = out_scale_per_n.repeat(D)
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
self.W_out.weight.mul_(out_scale_DN.unsqueeze(0))
|
| 152 |
+
|
| 153 |
+
def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
|
| 154 |
+
"""
|
| 155 |
+
并行前向传播:使用 parallel scan 处理全序列。
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
spike_in_seq: (TK, batch, D) — 全部 T×K 帧的输入 spike
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
continuous_out: (TK, batch, D) — 全部 T×K 帧的连续输出(V_post 经 W_out 投影)
|
| 162 |
+
"""
|
| 163 |
+
TK, batch, D = spike_in_seq.shape
|
| 164 |
+
DN = self.D * self.N
|
| 165 |
+
|
| 166 |
+
# ====== Phase 1: 批量投影(全部 TK 帧同时计算)======
|
| 167 |
+
flat = spike_in_seq.reshape(TK * batch, D)
|
| 168 |
+
|
| 169 |
+
I_all = F.linear(flat, self.W_in.weight).reshape(TK, batch, DN)
|
| 170 |
+
raw_beta = F.linear(flat, self.W_beta_x.weight).reshape(TK, batch, DN)
|
| 171 |
+
raw_alpha = F.linear(flat, self.W_alpha_x.weight).reshape(TK, batch, DN)
|
| 172 |
+
raw_th = F.linear(flat, self.W_th_x.weight).reshape(TK, batch, DN)
|
| 173 |
+
gate_all = torch.sigmoid(
|
| 174 |
+
F.linear(flat, self.W_gate.weight).reshape(TK, batch, D)
|
| 175 |
+
)
|
| 176 |
+
I_skip_all = F.linear(flat, self.W_skip.weight).reshape(TK, batch, D)
|
| 177 |
+
|
| 178 |
+
# ====== Phase 1b: 融合激活(torch.compile → 单 kernel)======
|
| 179 |
+
beta_all, u_hidden, v_th_all = _fused_modulation(
|
| 180 |
+
raw_beta, self.b_beta, raw_alpha, self.b_alpha,
|
| 181 |
+
raw_th, self.b_th, self.v_th_min, I_all,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# 获取隐神经元初始状态
|
| 185 |
+
v_init_hidden = self.hidden_neuron.v
|
| 186 |
+
if isinstance(v_init_hidden, float):
|
| 187 |
+
v_init_hidden = torch.zeros(batch, DN, device=flat.device, dtype=flat.dtype)
|
| 188 |
+
|
| 189 |
+
s_hidden, V_post_hidden, _ = plif_parallel_forward(
|
| 190 |
+
beta_all, u_hidden, v_th_all, v_init_hidden, max_iter=3,
|
| 191 |
+
surrogate_function=self.hidden_neuron.surrogate_function,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# 更新隐神经元状态(保存末步供下次调用)
|
| 195 |
+
self.hidden_neuron.v = V_post_hidden[-1].detach()
|
| 196 |
+
|
| 197 |
+
# ====== Phase 4: 输出投影(V_post → W_out: 连续梯度直通 β)======
|
| 198 |
+
# 用 V_post(膜电压)代替 spike 作为 W_out 输入,消除 surrogate 梯度瓶颈:
|
| 199 |
+
# spike 路径: ∂spike/∂β = surrogate'(V-v_th) · V_prev ≈ 0(大部分时刻)
|
| 200 |
+
# V_post 路径: ∂V_post/∂β = V_prev(无 surrogate 阻断,每步都有梯度)
|
| 201 |
+
v_flat = V_post_hidden.reshape(TK * batch, DN)
|
| 202 |
+
I_out_all = F.linear(v_flat, self.W_out.weight).reshape(TK, batch, D)
|
| 203 |
+
I_total_all = I_out_all * gate_all + I_skip_all # (TK, batch, D)
|
| 204 |
+
|
| 205 |
+
# output_neuron 已移除:连续值由层级 K 帧聚合处理
|
| 206 |
+
return I_total_all # (TK, batch, D), 连续值
|
| 207 |
+
|
| 208 |
+
def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
|
| 209 |
+
"""
|
| 210 |
+
单步前向传播(用于调试/兼容)。
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
spike_in: 二值脉冲输入, shape (batch, D), 值域 {0, 1}
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
continuous_out: 连续输出, shape (batch, D)
|
| 217 |
+
"""
|
| 218 |
+
V_prev = self.hidden_neuron.v
|
| 219 |
+
if isinstance(V_prev, float):
|
| 220 |
+
V_prev = torch.zeros(
|
| 221 |
+
spike_in.shape[0], self.D * self.N,
|
| 222 |
+
device=spike_in.device, dtype=spike_in.dtype,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
I_t = self.W_in(spike_in)
|
| 226 |
+
|
| 227 |
+
# β 调制仅依赖 spike_in
|
| 228 |
+
beta = torch.sigmoid(self.W_beta_x(spike_in) + self.b_beta)
|
| 229 |
+
alpha = F.softplus(self.W_alpha_x(spike_in) + self.b_alpha)
|
| 230 |
+
v_th = self.v_th_min + torch.abs(self.W_th_x(spike_in) + self.b_th)
|
| 231 |
+
|
| 232 |
+
gate = torch.sigmoid(self.W_gate(spike_in))
|
| 233 |
+
I_skip = self.W_skip(spike_in)
|
| 234 |
+
|
| 235 |
+
s_hidden = self.hidden_neuron(I_t, beta, alpha, v_th)
|
| 236 |
+
|
| 237 |
+
# 用 V_post(膜电压)做输出投影,与 forward_parallel 一致
|
| 238 |
+
V_post = self.hidden_neuron.v # 发放+重置后的膜电位
|
| 239 |
+
I_out = self.W_out(V_post)
|
| 240 |
+
I_total = I_out * gate + I_skip
|
| 241 |
+
|
| 242 |
+
return I_total # 连续值
|
atomic_ops/snn_decoder_layer.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SNNDecoderLayer: 单个 SNN 解码层(Pre-LN 连续残差流 + 动态 K 帧聚合)
|
| 3 |
+
|
| 4 |
+
RMSNorm → PLIF → SNNBlock → 动态K聚合 → out_proj → 残差
|
| 5 |
+
RMSNorm → PLIF → SNNFFN → 动态K聚合 → out_proj → 残差
|
| 6 |
+
|
| 7 |
+
动态 K:
|
| 8 |
+
- K 是最大步数(K_max),不是固定步数。不同 token 有效步数 ∈ [1, K_max]。
|
| 9 |
+
- 每个 token 的 K 帧 SNN 输出,学习自适应停止概率 p_halt
|
| 10 |
+
- PonderNet 几何分布加权:λ_k = p_k · ∏_{j<k}(1-p_j),归一化后加权聚合
|
| 11 |
+
- 不同 token 有效步数不同:简单 token 早停(E[K]小),复杂 token 用满步数
|
| 12 |
+
- ponder_cost 正则化:鼓励用更少步数完成简单 token 的处理
|
| 13 |
+
|
| 14 |
+
数学推导:
|
| 15 |
+
停止概率: p_k = σ(halt_proj(frame_k)) ∈ (0,1)
|
| 16 |
+
生存概率: S_k = ∏_{j=1}^{k-1} (1 - p_j) — 到第 k 步还没停
|
| 17 |
+
权重: λ_k = p_k · S_k — 恰好在第 k 步停止的概率
|
| 18 |
+
归一化: λ̂_k = λ_k / Σ_k λ_k — 确保权重和为 1
|
| 19 |
+
聚合: output = Σ_k λ̂_k · frame_k
|
| 20 |
+
代价: E[K] = Σ_k k · λ̂_k — 期望步数
|
| 21 |
+
|
| 22 |
+
K_max 设计原则:
|
| 23 |
+
K_max 越大,模型对复杂 token 的处理能力越强(更多步数可用),
|
| 24 |
+
但计算量和显存线性增长。K_max=32 允许 token 使用 1~32 步。
|
| 25 |
+
PonderNet 的 ponder_cost 正则化确保简单 token 不浪费步数。
|
| 26 |
+
|
| 27 |
+
K 帧层间聚合:
|
| 28 |
+
- SNN 子层输出 K 帧连续值(V_post 经投影),PonderNet 加权聚合为 1 per token
|
| 29 |
+
- 聚合后经 out_proj 投影,广播回 K 帧做残差
|
| 30 |
+
- 使 β 的时间动力学通过 K 帧聚合梯度有效传播
|
| 31 |
+
|
| 32 |
+
对标 Qwen3DecoderLayer(Pre-LN 模式完全等价):
|
| 33 |
+
Qwen3: RMSNorm → Attention → residual → RMSNorm → MLP → residual
|
| 34 |
+
SNN: RMSNorm → PLIF → SNNBlock → 动态K聚合 → out_proj → residual
|
| 35 |
+
→ RMSNorm → PLIF → SNNFFN → 动态K聚合 → out_proj → residual
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import math
|
| 39 |
+
|
| 40 |
+
import torch
|
| 41 |
+
import torch.nn as nn
|
| 42 |
+
import torch.nn.functional as F
|
| 43 |
+
from spikingjelly.activation_based import base, surrogate
|
| 44 |
+
|
| 45 |
+
from .plif_node import PLIFNode
|
| 46 |
+
from .rms_norm import RMSNorm
|
| 47 |
+
from .snn_block import SNNBlock
|
| 48 |
+
from .snn_ffn import SNNFFN
|
| 49 |
+
from .parallel_scan import plif_rowparam_forward
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ====== Fused halt weight computation (torch.compile) ======
|
| 53 |
+
# 7-8 个独立 element-wise kernel → 单 fused kernel
|
| 54 |
+
# sigmoid + clamp + log1p + cumsum + exp + normalize
|
| 55 |
+
# 首次调用触发 JIT 编译(~秒级),后续调用走缓存
|
| 56 |
+
|
| 57 |
+
@torch.compile(backend='inductor', fullgraph=True)
|
| 58 |
+
def _fused_geometric_halt(halt_logits):
|
| 59 |
+
"""融合计算 PonderNet 几何分布停止权重。
|
| 60 |
+
|
| 61 |
+
输入: halt_logits (seq_len, K, batch) — halt_proj 的原始输出
|
| 62 |
+
输出: halt_weights (seq_len, K, batch) — 归一化几何分布权重,sum=1
|
| 63 |
+
|
| 64 |
+
数学: p_k = σ(logit_k), S_k = ∏_{j<k}(1-p_j), λ_k = p_k·S_k, λ̂_k = λ_k/Σλ
|
| 65 |
+
"""
|
| 66 |
+
p_halt = torch.sigmoid(halt_logits).clamp(min=1e-7, max=1.0 - 1e-7)
|
| 67 |
+
log_1_minus_p = torch.log1p(-p_halt) # (seq_len, K, batch)
|
| 68 |
+
# Exclusive cumsum: log_survive[:, k, :] = Σ_{j<k} log(1-p_j)
|
| 69 |
+
# 避免 torch.cat: 用 cumsum([:, :-1]) 填充 [:, 1:]
|
| 70 |
+
log_survive = torch.zeros_like(log_1_minus_p)
|
| 71 |
+
log_survive[:, 1:, :] = torch.cumsum(log_1_minus_p[:, :-1, :], dim=1)
|
| 72 |
+
survive = torch.exp(log_survive) # (seq_len, K, batch)
|
| 73 |
+
halt_weights = p_halt * survive # λ_k = p_k · S_k
|
| 74 |
+
halt_weights = halt_weights / (halt_weights.sum(dim=1, keepdim=True) + 1e-8)
|
| 75 |
+
return halt_weights
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SNNDecoderLayer(base.MemoryModule):
|
| 79 |
+
"""
|
| 80 |
+
单个 SNN 解码层(连续残差流 + K 帧聚合版本)。
|
| 81 |
+
|
| 82 |
+
层间传递连续值 h (TK, batch, D),通过 PLIF 神经元转换为 spike,
|
| 83 |
+
输入 SNN 子层处理后,K 帧聚合为 1 per token,经 out_proj 投影,
|
| 84 |
+
广播回 K 帧做残差连接。
|
| 85 |
+
|
| 86 |
+
K 帧聚合使 β 的时间动力学(控制 K 步内的膜电位演化)产生可微分的
|
| 87 |
+
token 级效应,解决 β 梯度为纯噪声的问题。
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
D: 可见维度
|
| 91 |
+
N: 状态扩展因子
|
| 92 |
+
D_ff: FFN 中间层维度
|
| 93 |
+
v_th_min: SNNBlock 动态阈值下限
|
| 94 |
+
ffn_v_threshold: SNNFFN gate/up 神经元阈值
|
| 95 |
+
K: 每 token 的 SNN 时间步数
|
| 96 |
+
num_layers: 总层数(用于残差输出缩放 + SNNFFN down_proj 缩放)
|
| 97 |
+
layer_idx: 当前层索引
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
D: int,
|
| 103 |
+
N: int,
|
| 104 |
+
D_ff: int,
|
| 105 |
+
v_th_min: float,
|
| 106 |
+
ffn_v_threshold: float,
|
| 107 |
+
K: int = 16,
|
| 108 |
+
num_layers: int = 1,
|
| 109 |
+
layer_idx: int = 0,
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.D = D
|
| 113 |
+
self.K = K
|
| 114 |
+
|
| 115 |
+
self.snn_block = SNNBlock(
|
| 116 |
+
D=D, N=N, v_th_min=v_th_min,
|
| 117 |
+
)
|
| 118 |
+
self.snn_ffn = SNNFFN(
|
| 119 |
+
D=D, D_ff=D_ff,
|
| 120 |
+
output_v_threshold=ffn_v_threshold,
|
| 121 |
+
num_layers=num_layers,
|
| 122 |
+
layer_idx=layer_idx,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Pre-LN 分支归一化: h → RMSNorm → PLIFNode
|
| 126 |
+
self.block_norm = RMSNorm(D)
|
| 127 |
+
self.ffn_norm = RMSNorm(D)
|
| 128 |
+
|
| 129 |
+
# 输入神经元: RMSNorm(h) → V_post 膜电位激活(D 维可学习 β 和 V_th)
|
| 130 |
+
self.input_neuron1 = PLIFNode(
|
| 131 |
+
dim=D,
|
| 132 |
+
init_tau=2.0,
|
| 133 |
+
v_threshold=0.5,
|
| 134 |
+
surrogate_function=surrogate.Sigmoid(alpha=4.0),
|
| 135 |
+
)
|
| 136 |
+
self.input_neuron2 = PLIFNode(
|
| 137 |
+
dim=D,
|
| 138 |
+
init_tau=2.0,
|
| 139 |
+
v_threshold=0.5,
|
| 140 |
+
surrogate_function=surrogate.Sigmoid(alpha=4.0),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# 输出投影(突触): spike (D) → 连续空间 (D)
|
| 144 |
+
self.block_out_proj = nn.Linear(D, D, bias=False)
|
| 145 |
+
self.ffn_out_proj = nn.Linear(D, D, bias=False)
|
| 146 |
+
|
| 147 |
+
# ====== 动态 K: 停止投影(突触: SNN 输出 → 停止概率) ======
|
| 148 |
+
# halt_proj: D → 1,每步每 token 产生一个停止 logit
|
| 149 |
+
# PonderNet 几何分布加权,替代 uniform mean 聚合
|
| 150 |
+
self.block_halt = nn.Linear(D, 1, bias=True)
|
| 151 |
+
self.ffn_halt = nn.Linear(D, 1, bias=True)
|
| 152 |
+
|
| 153 |
+
# 残差输出缩放初始化(GPT-2 style: σ = 0.02 / √(2·num_layers))
|
| 154 |
+
std = 0.02 / math.sqrt(2 * num_layers)
|
| 155 |
+
nn.init.normal_(self.block_out_proj.weight, std=std)
|
| 156 |
+
nn.init.normal_(self.ffn_out_proj.weight, std=std)
|
| 157 |
+
|
| 158 |
+
# halt 初始化: 小权重 + 负偏置 → p_halt ≈ 0.03 → 接近 uniform 聚合
|
| 159 |
+
# σ(-3.5) ≈ 0.029, 几何分布归一化后 λ_1/λ_K ≈ 1.5, 接近均匀
|
| 160 |
+
for halt in [self.block_halt, self.ffn_halt]:
|
| 161 |
+
nn.init.xavier_uniform_(halt.weight)
|
| 162 |
+
halt.weight.data.mul_(0.01)
|
| 163 |
+
nn.init.constant_(halt.bias, -3.5)
|
| 164 |
+
|
| 165 |
+
def _input_neuron_parallel(self, input_neuron, x):
|
| 166 |
+
"""
|
| 167 |
+
输入 PLIF 神经元的 parallel scan 前向传播。
|
| 168 |
+
|
| 169 |
+
完整 PLIF 动力学: V[t] = β·V[t-1] + (1-β)·x[t], spike = Θ(V-V_th), 软重置。
|
| 170 |
+
输出膜电位泄漏量 (1-β)·V_post 作为激活值——即每步因指数衰减将泄漏的量。
|
| 171 |
+
相比直接传递 V_post,泄漏量自然强调快响应神经元(大 1-β),
|
| 172 |
+
抑制慢记忆神经元(小 1-β),实现隐式的时间尺度加权。
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
input_neuron: PLIFNode 实例(D 维可学习 β 和 V_th)
|
| 176 |
+
x: (TK, batch, D) — 连续值输入
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
leak: (TK, batch, D) — 膜电位泄漏量 (1-β)·V_post
|
| 180 |
+
"""
|
| 181 |
+
TK, batch, D = x.shape
|
| 182 |
+
|
| 183 |
+
beta = input_neuron.beta # (D,)
|
| 184 |
+
u = (1.0 - beta) * x # (D,) broadcast → (TK, batch, D)
|
| 185 |
+
|
| 186 |
+
v_init = input_neuron.v
|
| 187 |
+
if isinstance(v_init, float):
|
| 188 |
+
v_init = torch.zeros(batch, D, device=x.device, dtype=x.dtype)
|
| 189 |
+
|
| 190 |
+
beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
|
| 191 |
+
v_th_row = input_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()
|
| 192 |
+
|
| 193 |
+
spike, V_post = plif_rowparam_forward(
|
| 194 |
+
beta_row, u, v_th_row, v_init,
|
| 195 |
+
surrogate_function=input_neuron.surrogate_function,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
input_neuron.v = V_post[-1].detach()
|
| 199 |
+
return (1.0 - beta) * V_post # 膜电位泄漏量
|
| 200 |
+
|
| 201 |
+
def _adaptive_aggregate(self, frames, halt_proj):
|
| 202 |
+
"""
|
| 203 |
+
PonderNet 式自适应 K 帧聚合(动态 K 核心,torch.compile 融合优化)。
|
| 204 |
+
|
| 205 |
+
每步计算停止概率 p_k,用几何分布权重加权聚合,
|
| 206 |
+
使不同 token 有不同的有效步数。
|
| 207 |
+
|
| 208 |
+
优化: _fused_geometric_halt 将 sigmoid+log1p+cumsum+exp+normalize
|
| 209 |
+
融合为单 inductor kernel(参见 snn_block._fused_modulation 同一模式)。
|
| 210 |
+
|
| 211 |
+
数学:
|
| 212 |
+
p_k = σ(halt_proj(frame_k)) — 停止概率
|
| 213 |
+
S_k = ∏_{j<k} (1-p_j) — 生存概率
|
| 214 |
+
λ_k = p_k · S_k — 几何分布权重
|
| 215 |
+
λ̂_k = λ_k / Σ λ_k — 归一化
|
| 216 |
+
output = Σ λ̂_k · frame_k — 加权聚合
|
| 217 |
+
E[K] = Σ k · λ̂_k — 期望步数(ponder cost)
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
frames: (seq_len, K, batch, D) — SNN 子层 K 帧输出
|
| 221 |
+
halt_proj: nn.Linear(D, 1) — 停止投影(突触)
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
aggregated: (seq_len, batch, D) — 加权聚合结果
|
| 225 |
+
ponder_cost: scalar — 期望步数均值(正则化用)
|
| 226 |
+
"""
|
| 227 |
+
seq_len, K, batch, D = frames.shape
|
| 228 |
+
|
| 229 |
+
# ====== 1. halt_proj matmul(cuBLAS)+ 融合几何权重(inductor) ======
|
| 230 |
+
halt_logits = halt_proj(frames).squeeze(-1) # (seq_len, K, batch)
|
| 231 |
+
halt_weights = _fused_geometric_halt(halt_logits) # (seq_len, K, batch), 归一化
|
| 232 |
+
|
| 233 |
+
# ====== 2. 加权聚合 ======
|
| 234 |
+
# (seq_len, K, batch, 1) × (seq_len, K, batch, D) → sum → (seq_len, batch, D)
|
| 235 |
+
aggregated = (frames * halt_weights.unsqueeze(-1)).sum(dim=1)
|
| 236 |
+
|
| 237 |
+
# ====== 3. Ponder cost: E[K] per token ======
|
| 238 |
+
steps = torch.arange(1, K + 1, device=frames.device, dtype=frames.dtype)
|
| 239 |
+
expected_k = (halt_weights * steps[None, :, None]).sum(dim=1) # (seq_len, batch)
|
| 240 |
+
ponder_cost = expected_k.mean() # scalar
|
| 241 |
+
|
| 242 |
+
return aggregated, ponder_cost, expected_k.detach()
|
| 243 |
+
|
| 244 |
+
def forward_parallel(self, h):
|
| 245 |
+
"""
|
| 246 |
+
并行前向传播:连续残差流 + 动态 K 帧聚合。
|
| 247 |
+
|
| 248 |
+
SNN 子层在 TK 维度处理(K 步时间动力学),输出后用 PonderNet
|
| 249 |
+
自适应聚合 K 帧(不同 token 有效步数不同),经 out_proj 投影后
|
| 250 |
+
广播回 TK 做残差。
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
h: (TK, batch, D) — 连续值输入
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
h: (TK, batch, D) — 连续值输出
|
| 257 |
+
ponder_cost: scalar — 两个子层的平均期望步数(正则化用)
|
| 258 |
+
"""
|
| 259 |
+
TK, batch, D = h.shape
|
| 260 |
+
K = self.K
|
| 261 |
+
seq_len = TK // K
|
| 262 |
+
|
| 263 |
+
# 子层 1: SNNBlock — RMSNorm → PLIFNode(V_post) → SNNBlock → 动态K聚合 → out_proj → 残差
|
| 264 |
+
v_in = self._input_neuron_parallel(self.input_neuron1, self.block_norm(h))
|
| 265 |
+
cont_block = self.snn_block.forward_parallel(v_in) # (TK, batch, D), 连续值
|
| 266 |
+
|
| 267 |
+
# 动态 K 帧聚合(PonderNet): (TK, batch, D) → (seq_len, K, batch, D) → 加权 → (seq_len, batch, D)
|
| 268 |
+
frames_block = cont_block.view(seq_len, K, batch, D)
|
| 269 |
+
combined_block, pc_block, ek_block = self._adaptive_aggregate(frames_block, self.block_halt)
|
| 270 |
+
res_block = self.block_out_proj(combined_block) # (seq_len, batch, D)
|
| 271 |
+
res_block = res_block - res_block.mean(dim=-1, keepdim=True) # 残差中心化
|
| 272 |
+
|
| 273 |
+
# 广播回 TK:每 token 的残差复制 K 份
|
| 274 |
+
h = h + res_block.repeat_interleave(K, dim=0)
|
| 275 |
+
|
| 276 |
+
# 子层 2: SNNFFN — RMSNorm → PLIFNode(V_post) → SNNFFN → 动态K聚合 → out_proj → 残差
|
| 277 |
+
v_in2 = self._input_neuron_parallel(self.input_neuron2, self.ffn_norm(h))
|
| 278 |
+
cont_ffn = self.snn_ffn.forward_parallel(v_in2) # (TK, batch, D), 连续值
|
| 279 |
+
|
| 280 |
+
frames_ffn = cont_ffn.view(seq_len, K, batch, D)
|
| 281 |
+
combined_ffn, pc_ffn, ek_ffn = self._adaptive_aggregate(frames_ffn, self.ffn_halt)
|
| 282 |
+
res_ffn = self.ffn_out_proj(combined_ffn)
|
| 283 |
+
res_ffn = res_ffn - res_ffn.mean(dim=-1, keepdim=True)
|
| 284 |
+
|
| 285 |
+
h = h + res_ffn.repeat_interleave(K, dim=0)
|
| 286 |
+
|
| 287 |
+
ponder_cost = (pc_block + pc_ffn) / 2.0 # 两个子层平均
|
| 288 |
+
|
| 289 |
+
# 存储 per-token E[K] 范围(诊断用,不影响计算图)
|
| 290 |
+
# ek_block/ek_ffn: (seq_len, batch), detached
|
| 291 |
+
with torch.no_grad():
|
| 292 |
+
all_ek = torch.cat([ek_block.flatten(), ek_ffn.flatten()])
|
| 293 |
+
self._ek_min = all_ek.min().item()
|
| 294 |
+
self._ek_max = all_ek.max().item()
|
| 295 |
+
|
| 296 |
+
return h, ponder_cost
|
| 297 |
+
|
| 298 |
+
def single_step_forward(self, h):
|
| 299 |
+
"""
|
| 300 |
+
单步前向传播:连续残差流。
|
| 301 |
+
|
| 302 |
+
注意:单步模式无法做动态 K 聚合(每步独立处理)。
|
| 303 |
+
训练和推理均使用 forward_parallel(含动态 K 聚合)。
|
| 304 |
+
此方法仅用于调试。
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
h: (batch, D) — 连续值输入
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
h: (batch, D) — 连续值输出
|
| 311 |
+
ponder_cost: scalar — 0.0(单步无 ponder cost)
|
| 312 |
+
"""
|
| 313 |
+
# 子层 1: SNNBlock — RMSNorm → PLIFNode(leak) → SNNBlock → out_proj → 残差
|
| 314 |
+
_ = self.input_neuron1(self.block_norm(h)) # 触发 PLIF 动力学,更新 .v
|
| 315 |
+
v_in = (1.0 - self.input_neuron1.beta) * self.input_neuron1.v # 膜电位泄漏量
|
| 316 |
+
cont_block = self.snn_block.single_step_forward(v_in)
|
| 317 |
+
res_block = self.block_out_proj(cont_block)
|
| 318 |
+
h = h + res_block - res_block.mean(dim=-1, keepdim=True)
|
| 319 |
+
|
| 320 |
+
# 子层 2: SNNFFN — RMSNorm → PLIFNode(leak) → SNNFFN → out_proj → 残差
|
| 321 |
+
_ = self.input_neuron2(self.ffn_norm(h))
|
| 322 |
+
v_in2 = (1.0 - self.input_neuron2.beta) * self.input_neuron2.v # 膜电位泄漏量
|
| 323 |
+
cont_ffn = self.snn_ffn.single_step_forward(v_in2)
|
| 324 |
+
res_ffn = self.ffn_out_proj(cont_ffn)
|
| 325 |
+
h = h + res_ffn - res_ffn.mean(dim=-1, keepdim=True)
|
| 326 |
+
|
| 327 |
+
return h, torch.tensor(0.0, device=h.device)
|
atomic_ops/snn_ffn.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SNNFFN: SNN 等价的 Feed-Forward Network
|
| 3 |
+
|
| 4 |
+
对标 Qwen3MLP 的 SwiGLU 结构:
|
| 5 |
+
Qwen3 MLP: down_proj( SiLU(gate_proj(x)) * up_proj(x) )
|
| 6 |
+
SNN FFN: down_proj( gate_V_post * up_V_post ) + skip
|
| 7 |
+
|
| 8 |
+
膜电位门控(对标 SiLU gating):
|
| 9 |
+
gate/up 神经元完整 PLIF 动力学(积分+阈值+重置),
|
| 10 |
+
输出膜电位 V_post 做连续乘法门控,替代 binary AND 门。
|
| 11 |
+
|
| 12 |
+
信号流:
|
| 13 |
+
x → gate_proj → gate_neuron → V_post_gate
|
| 14 |
+
x → up_proj → up_neuron → V_post_up
|
| 15 |
+
V_post_gate × V_post_up → gated
|
| 16 |
+
down_proj(gated) + skip_proj(x) → 连续输出
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from spikingjelly.activation_based import base, layer, surrogate
|
| 24 |
+
|
| 25 |
+
from .plif_node import PLIFNode
|
| 26 |
+
from .parallel_scan import plif_rowparam_forward
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SNNFFN(base.MemoryModule):
|
| 30 |
+
"""
|
| 31 |
+
SNN 等价的 Feed-Forward Network。
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
D: 可见维度(输入/输出 spike 维度)
|
| 35 |
+
D_ff: 中间层维度(对标 Qwen3 intermediate_size)
|
| 36 |
+
output_v_threshold: 输出神经元阈值
|
| 37 |
+
num_layers: 总层数,用于 down_proj 缩放
|
| 38 |
+
layer_idx: 当前层索引
|
| 39 |
+
surrogate_function: surrogate gradient 函数
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
D: int,
|
| 45 |
+
D_ff: int,
|
| 46 |
+
output_v_threshold: float = 0.3,
|
| 47 |
+
num_layers: int = 1,
|
| 48 |
+
layer_idx: int = 0,
|
| 49 |
+
surrogate_function=surrogate.Sigmoid(alpha=4.0),
|
| 50 |
+
):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.D = D
|
| 53 |
+
self.D_ff = D_ff
|
| 54 |
+
|
| 55 |
+
# ====== 三条投影路径(对标 SwiGLU: gate_proj, up_proj, down_proj) ======
|
| 56 |
+
self.gate_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
|
| 57 |
+
self.up_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
|
| 58 |
+
self.down_proj = layer.Linear(D_ff, D, bias=False, step_mode='s')
|
| 59 |
+
|
| 60 |
+
# ====== 残差路径 ======
|
| 61 |
+
self.skip_proj = layer.Linear(D, D, bias=False, step_mode='s')
|
| 62 |
+
|
| 63 |
+
# ====== 神经元(D 维或 D_ff 维可学习 β 和 V_th) ======
|
| 64 |
+
# gate_neuron: 门控发放
|
| 65 |
+
self.gate_neuron = PLIFNode(
|
| 66 |
+
dim=D_ff,
|
| 67 |
+
init_tau=2.0,
|
| 68 |
+
v_threshold=output_v_threshold,
|
| 69 |
+
surrogate_function=surrogate_function,
|
| 70 |
+
)
|
| 71 |
+
# up_neuron: 值发放
|
| 72 |
+
self.up_neuron = PLIFNode(
|
| 73 |
+
dim=D_ff,
|
| 74 |
+
init_tau=2.0,
|
| 75 |
+
v_threshold=output_v_threshold,
|
| 76 |
+
surrogate_function=surrogate_function,
|
| 77 |
+
)
|
| 78 |
+
# ====== 参数初始化 ======
|
| 79 |
+
self._initialize_parameters(num_layers)
|
| 80 |
+
|
| 81 |
+
def _initialize_parameters(self, num_layers: int):
|
| 82 |
+
"""初始化投影权重。
|
| 83 |
+
|
| 84 |
+
- gate_proj, up_proj, skip_proj: Kaiming uniform
|
| 85 |
+
- down_proj: Kaiming uniform × 1/√(num_layers),防深层梯度爆炸
|
| 86 |
+
"""
|
| 87 |
+
for lin in [self.gate_proj, self.up_proj, self.skip_proj]:
|
| 88 |
+
nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
|
| 89 |
+
|
| 90 |
+
nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
|
| 91 |
+
self.down_proj.weight.data.mul_(1.0 / math.sqrt(num_layers))
|
| 92 |
+
|
| 93 |
+
def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
|
| 94 |
+
"""
|
| 95 |
+
并行前向传播:使用 parallel scan 处理全序列。
|
| 96 |
+
|
| 97 |
+
优化:
|
| 98 |
+
- gate_proj + up_proj 合并为单次 matmul(2 launch → 1)
|
| 99 |
+
- gate + up PLIF scan: row-param kernel(无需 expand+contiguous beta/v_th)
|
| 100 |
+
- u_merged: 向量缩放替代 cat(1次 broadcast multiply 替代 2次 scale + 1次 cat)
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
spike_in_seq: (TK, batch, D) — 全部 T×K 帧的输入 spike
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
continuous_out: (TK, batch, D) — 全部 T×K 帧的连续输出
|
| 107 |
+
"""
|
| 108 |
+
TK, batch, D = spike_in_seq.shape
|
| 109 |
+
D_ff = self.D_ff
|
| 110 |
+
flat = spike_in_seq.reshape(TK * batch, D)
|
| 111 |
+
|
| 112 |
+
# ====== Phase 1: 批量投影(gate+up 合并为 1 次 matmul) ======
|
| 113 |
+
W_gate_up = torch.cat([self.gate_proj.weight, self.up_proj.weight], dim=0)
|
| 114 |
+
I_gate_up = F.linear(flat, W_gate_up).reshape(TK, batch, 2 * D_ff)
|
| 115 |
+
I_skip = F.linear(flat, self.skip_proj.weight).reshape(TK, batch, D)
|
| 116 |
+
|
| 117 |
+
# ====== Phase 2: Gate+Up 合并 PLIF scan(row-param kernel) ======
|
| 118 |
+
beta_gate = self.gate_neuron.beta # (D_ff,)
|
| 119 |
+
beta_up = self.up_neuron.beta # (D_ff,)
|
| 120 |
+
surr = self.gate_neuron.surrogate_function
|
| 121 |
+
|
| 122 |
+
# u_merged: 向量缩放(D_ff 维 β 直接 cat,无需 expand)
|
| 123 |
+
scale_row = torch.cat([1.0 - beta_gate, 1.0 - beta_up]) # (2*D_ff,)
|
| 124 |
+
u_merged = I_gate_up * scale_row # (TK, batch, 2*D_ff), broadcast
|
| 125 |
+
|
| 126 |
+
# beta_row / v_th_row: (batch, 2*D_ff) — D_ff 维可学习参数
|
| 127 |
+
beta_row = torch.cat([beta_gate, beta_up]) # (2*D_ff,)
|
| 128 |
+
beta_row = beta_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()
|
| 129 |
+
|
| 130 |
+
v_th_row = torch.cat([self.gate_neuron.v_th, self.up_neuron.v_th]) # (2*D_ff,)
|
| 131 |
+
v_th_row = v_th_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()
|
| 132 |
+
|
| 133 |
+
# v_init_merged: (batch, 2*D_ff)
|
| 134 |
+
v_init_gate = self.gate_neuron.v
|
| 135 |
+
if isinstance(v_init_gate, float):
|
| 136 |
+
v_init_gate = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
|
| 137 |
+
v_init_up = self.up_neuron.v
|
| 138 |
+
if isinstance(v_init_up, float):
|
| 139 |
+
v_init_up = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
|
| 140 |
+
v_init_merged = torch.cat([v_init_gate, v_init_up], dim=-1)
|
| 141 |
+
|
| 142 |
+
# Row-param PLIF scan: beta/v_th 从寄存器读取,不占显存带宽
|
| 143 |
+
spike_merged, V_post_merged = plif_rowparam_forward(
|
| 144 |
+
beta_row, u_merged, v_th_row, v_init_merged,
|
| 145 |
+
surrogate_function=surr,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# 膜电位泄漏量作为激活值: leak = (1-β) · V_post
|
| 149 |
+
gate_leak = V_post_merged[:, :, :D_ff] * (1.0 - beta_gate) # (TK, batch, D_ff)
|
| 150 |
+
up_leak = V_post_merged[:, :, D_ff:] * (1.0 - beta_up) # (TK, batch, D_ff)
|
| 151 |
+
self.gate_neuron.v = V_post_merged[-1, :, :D_ff].detach()
|
| 152 |
+
self.up_neuron.v = V_post_merged[-1, :, D_ff:].detach()
|
| 153 |
+
|
| 154 |
+
# ====== Phase 3: 连续门控(leak × leak,对标 SwiGLU)+ 降维 ======
|
| 155 |
+
gated = gate_leak * up_leak # (TK, batch, D_ff)
|
| 156 |
+
gated_flat = gated.reshape(TK * batch, D_ff)
|
| 157 |
+
I_out = F.linear(gated_flat, self.down_proj.weight).reshape(TK, batch, D) + I_skip
|
| 158 |
+
|
| 159 |
+
# output_neuron 已移除:连续值由层级 K 帧聚合处理
|
| 160 |
+
return I_out # (TK, batch, D), 连续值
|
| 161 |
+
|
| 162 |
+
def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
|
| 163 |
+
"""
|
| 164 |
+
单步前向传播。
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
spike_in: 二值脉冲输入, shape (batch, D), 值域 {0, 1}
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
continuous_out: 连续输出, shape (batch, D)
|
| 171 |
+
"""
|
| 172 |
+
# 门控路径 — 膜电位泄漏量激活
|
| 173 |
+
_ = self.gate_neuron(self.gate_proj(spike_in))
|
| 174 |
+
gate_leak = (1.0 - self.gate_neuron.beta) * self.gate_neuron.v # leak
|
| 175 |
+
|
| 176 |
+
# 值路径 — 膜电位泄漏量激活
|
| 177 |
+
_ = self.up_neuron(self.up_proj(spike_in))
|
| 178 |
+
up_leak = (1.0 - self.up_neuron.beta) * self.up_neuron.v # leak
|
| 179 |
+
|
| 180 |
+
# 连续门控(对标 SwiGLU)
|
| 181 |
+
gated = gate_leak * up_leak
|
| 182 |
+
|
| 183 |
+
# 降维 + 残差
|
| 184 |
+
I_out = self.down_proj(gated) + self.skip_proj(spike_in) # R^D
|
| 185 |
+
return I_out # 连续值
|
config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "neuronspark",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"NeuronSparkForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_neuronspark.NeuronSparkConfig",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_neuronspark.NeuronSparkForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"vocab_size": 6144,
|
| 11 |
+
"D": 896,
|
| 12 |
+
"N": 8,
|
| 13 |
+
"K": 16,
|
| 14 |
+
"num_layers": 20,
|
| 15 |
+
"D_ff": 2688,
|
| 16 |
+
"v_th_min": 0.1,
|
| 17 |
+
"torch_dtype": "float32",
|
| 18 |
+
"transformers_version": "4.52.0",
|
| 19 |
+
"_training_step": 6500,
|
| 20 |
+
"_tokens_seen": 10459816
|
| 21 |
+
}
|
configuration.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"framework": "pytorch", "task": "other"}
|
configuration_neuronspark.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""NeuronSpark 模型配置。"""
|
| 2 |
+
|
| 3 |
+
from transformers import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class NeuronSparkConfig(PretrainedConfig):
|
| 7 |
+
"""
|
| 8 |
+
SNN 隐状态空间语言模型配置。
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
vocab_size: 词表大小
|
| 12 |
+
D: 隐层维度
|
| 13 |
+
N: 状态扩展因子(每通道隐神经元数)
|
| 14 |
+
K: 每 token 最大 SNN 时间步(PonderNet 动态决定有效步数)
|
| 15 |
+
num_layers: SNN 解码层数
|
| 16 |
+
D_ff: FFN 中间层维度
|
| 17 |
+
v_th_min: 动态阈值下限
|
| 18 |
+
"""
|
| 19 |
+
model_type = "neuronspark"
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
vocab_size=6144,
|
| 24 |
+
D=896,
|
| 25 |
+
N=8,
|
| 26 |
+
K=16,
|
| 27 |
+
num_layers=20,
|
| 28 |
+
D_ff=2688,
|
| 29 |
+
v_th_min=0.1,
|
| 30 |
+
**kwargs,
|
| 31 |
+
):
|
| 32 |
+
self.D = D
|
| 33 |
+
self.N = N
|
| 34 |
+
self.K = K
|
| 35 |
+
self.num_layers = num_layers
|
| 36 |
+
self.D_ff = D_ff
|
| 37 |
+
self.v_th_min = v_th_min
|
| 38 |
+
super().__init__(vocab_size=vocab_size, **kwargs)
|
model.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SNNLanguageModel: SNN 隐状态空间语言模型(全膜电位 + 动态 K)
|
| 3 |
+
|
| 4 |
+
架构(三段式):
|
| 5 |
+
model.encode(token_ids) → h_seq # 输入: embed → repeat K 次(可微分)
|
| 6 |
+
model.snn_forward(h_seq) → h_out, pc # SNN 核心: 20 层,全膜电位 + 动态 K 聚合
|
| 7 |
+
model.decode(h_out, seq) → logits # 输出: output_neuron(V_post) → K帧mean → proj → logits
|
| 8 |
+
|
| 9 |
+
核心设计:
|
| 10 |
+
1. 膜电位泄漏量:PLIFNode 输出 (1-β)·V_post(泄漏量),自然强调快响应神经元
|
| 11 |
+
2. 动态 K:PonderNet 自适应停止,不同 token 不同有效步数
|
| 12 |
+
- 每层每子层学习 halt_proj(D→1),从 SNN 输出逐步计算停止概率
|
| 13 |
+
- 几何分布权重加权聚合,替代 uniform mean
|
| 14 |
+
- ponder_cost 正则化鼓励早停
|
| 15 |
+
|
| 16 |
+
数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
from spikingjelly.activation_based import functional, surrogate
|
| 27 |
+
from torch.utils.checkpoint import checkpoint
|
| 28 |
+
|
| 29 |
+
from atomic_ops import SNNDecoderLayer
|
| 30 |
+
from atomic_ops.plif_node import PLIFNode
|
| 31 |
+
from atomic_ops.rms_norm import RMSNorm
|
| 32 |
+
from atomic_ops.parallel_scan import plif_rowparam_forward
|
| 33 |
+
# fp16_encode/fp16_decode 已移除: 全膜电位架构不需要 spike 编解码
|
| 34 |
+
from atomic_ops.lateral_inhibition import LateralInhibition
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class SNNModelOutput:
|
| 39 |
+
"""模型输出容器,对齐教程 CausalLMOutputWithPast 接口。"""
|
| 40 |
+
last_loss: Optional[torch.Tensor] = None
|
| 41 |
+
logits: Optional[torch.Tensor] = None
|
| 42 |
+
ponder_cost: Optional[torch.Tensor] = None # 动态 K: 平均期望步数
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SNNLanguageModel(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
从零训练的 SNN 隐状态空间语言模型(parallel scan)。
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
vocab_size: 词表大小(默认 6144,自训练 BPE)
|
| 51 |
+
D: 可见维度
|
| 52 |
+
N: 状态扩展因子
|
| 53 |
+
K: 每 token 最大 SNN 时间步数(K_max)。PonderNet 动态决定有效步数 ∈ [1, K]。
|
| 54 |
+
K 越大 → 复杂 token 可用更多步数,但计算量和显存线性增长。
|
| 55 |
+
num_layers: SNN 解码层数
|
| 56 |
+
D_ff: FFN 中间层维度
|
| 57 |
+
v_th_min: 动态阈值下限
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
vocab_size: int = 6144,
|
| 63 |
+
D: int = 1024,
|
| 64 |
+
N: int = 8,
|
| 65 |
+
K: int = 32,
|
| 66 |
+
num_layers: int = 20,
|
| 67 |
+
D_ff: int = 3072,
|
| 68 |
+
v_th_min: float = 0.1,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.vocab_size = vocab_size
|
| 72 |
+
self.D = D
|
| 73 |
+
self.N = N
|
| 74 |
+
self.K = K
|
| 75 |
+
self.num_layers = num_layers
|
| 76 |
+
self.D_ff = D_ff
|
| 77 |
+
|
| 78 |
+
# ====== Embedding + Norm(全部可训练)======
|
| 79 |
+
self.embed_tokens = nn.Embedding(vocab_size, D)
|
| 80 |
+
self.norm = LateralInhibition(D)
|
| 81 |
+
|
| 82 |
+
# ====== 解码投影 ======
|
| 83 |
+
self.decode_proj = nn.Linear(D, D)
|
| 84 |
+
|
| 85 |
+
# ====== 输出 RMSNorm + 输出神经元 ======
|
| 86 |
+
self.output_norm = RMSNorm(D)
|
| 87 |
+
self.output_neuron = PLIFNode(
|
| 88 |
+
dim=D,
|
| 89 |
+
init_tau=2.0,
|
| 90 |
+
v_threshold=0.3,
|
| 91 |
+
surrogate_function=surrogate.Sigmoid(alpha=4.0),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# ====== SNN Decoder Layers ======
|
| 95 |
+
self.layers = nn.ModuleList([
|
| 96 |
+
SNNDecoderLayer(
|
| 97 |
+
D=D, N=N, D_ff=D_ff, v_th_min=v_th_min,
|
| 98 |
+
ffn_v_threshold=0.15,
|
| 99 |
+
K=K,
|
| 100 |
+
num_layers=num_layers,
|
| 101 |
+
layer_idx=i,
|
| 102 |
+
)
|
| 103 |
+
for i in range(num_layers)
|
| 104 |
+
])
|
| 105 |
+
|
| 106 |
+
self._init_weights()
|
| 107 |
+
|
| 108 |
+
def _init_weights(self):
|
| 109 |
+
"""初始化所有可训练权重(从零训练)。"""
|
| 110 |
+
nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
|
| 111 |
+
nn.init.xavier_uniform_(self.decode_proj.weight)
|
| 112 |
+
nn.init.zeros_(self.decode_proj.bias)
|
| 113 |
+
|
| 114 |
+
def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
"""输入边界:token_ids → 连续值序列。
|
| 116 |
+
|
| 117 |
+
Embedding lookup,每 token 重复 K 次作为 SNN 时间步输入。
|
| 118 |
+
梯度可通过 embedding 直接反传。
|
| 119 |
+
|
| 120 |
+
Returns: (seq_len*K, batch, D), 连续值
|
| 121 |
+
"""
|
| 122 |
+
emb = self.embed_tokens(token_ids) # (batch, seq_len, D)
|
| 123 |
+
batch, seq_len, D = emb.shape
|
| 124 |
+
# 每 token 重复 K 次: (batch, seq_len, D) → (batch, seq_len*K, D) → (TK, batch, D)
|
| 125 |
+
emb_k = emb.unsqueeze(2).expand(-1, -1, self.K, -1).reshape(batch, seq_len * self.K, D)
|
| 126 |
+
return emb_k.permute(1, 0, 2).contiguous() # (TK, batch, D)
|
| 127 |
+
|
| 128 |
+
def snn_forward(self, spike_seq: torch.Tensor):
|
| 129 |
+
"""SNN 核心:spike_seq → (h_out, ponder_cost)。
|
| 130 |
+
|
| 131 |
+
纯 SNN 层计算,带梯度检查点。
|
| 132 |
+
每层返回 (h, ponder_cost),ponder_cost 作为 checkpoint 输出保留梯度图。
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
h: (seq_len*K, batch, D), 连续值
|
| 136 |
+
total_ponder_cost: scalar, 所有层平均期望步数
|
| 137 |
+
"""
|
| 138 |
+
h = spike_seq
|
| 139 |
+
ponder_costs = []
|
| 140 |
+
|
| 141 |
+
def _layer_forward(layer_mod, x):
|
| 142 |
+
functional.reset_net(layer_mod)
|
| 143 |
+
return layer_mod.forward_parallel(x) # returns (h, ponder_cost)
|
| 144 |
+
|
| 145 |
+
for layer_module in self.layers:
|
| 146 |
+
h, pc = checkpoint(
|
| 147 |
+
_layer_forward, layer_module, h,
|
| 148 |
+
use_reentrant=False,
|
| 149 |
+
)
|
| 150 |
+
ponder_costs.append(pc)
|
| 151 |
+
|
| 152 |
+
total_ponder_cost = sum(ponder_costs) / len(ponder_costs)
|
| 153 |
+
return h, total_ponder_cost
|
| 154 |
+
|
| 155 |
+
def _output_neuron_parallel(self, h: torch.Tensor) -> torch.Tensor:
|
| 156 |
+
"""输出 PLIF 神经元的 parallel scan 前向:连续 h → 膜电位泄漏量。
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
h: (TK, batch, D) 连续值(SNN 最后一层输出)
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
leak: (TK, batch, D) 膜电位泄漏量 (1-β)·V_post
|
| 163 |
+
"""
|
| 164 |
+
TK, batch, D = h.shape
|
| 165 |
+
|
| 166 |
+
beta = self.output_neuron.beta # (D,)
|
| 167 |
+
u = (1.0 - beta) * h # PLIF: u = (1-β) · x
|
| 168 |
+
|
| 169 |
+
v_init = self.output_neuron.v
|
| 170 |
+
if isinstance(v_init, float):
|
| 171 |
+
v_init = torch.zeros(batch, D, device=h.device, dtype=h.dtype)
|
| 172 |
+
|
| 173 |
+
beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
|
| 174 |
+
v_th_row = self.output_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()
|
| 175 |
+
|
| 176 |
+
spike, V_post = plif_rowparam_forward(
|
| 177 |
+
beta_row, u, v_th_row, v_init,
|
| 178 |
+
surrogate_function=self.output_neuron.surrogate_function,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
self.output_neuron.v = V_post[-1].detach()
|
| 182 |
+
return (1.0 - beta) * V_post # 膜电位泄漏量
|
| 183 |
+
|
| 184 |
+
def decode(self, h_out: torch.Tensor, seq_len: int) -> torch.Tensor:
|
| 185 |
+
"""输出边界:连续 h → 输出神经元(V_post) → K 帧聚合 → logits。
|
| 186 |
+
|
| 187 |
+
梯度流: loss → logits → norm → decode_proj → K帧mean
|
| 188 |
+
→ V_post(output_neuron) → h_out → SNN layers
|
| 189 |
+
|
| 190 |
+
Returns: (batch, seq_len, vocab_size)
|
| 191 |
+
"""
|
| 192 |
+
h_out = self.output_norm(h_out) # RMSNorm: 控制 scale
|
| 193 |
+
v_out = self._output_neuron_parallel(h_out) # (TK, batch, D), V_post 膜电位
|
| 194 |
+
# K 帧聚合: (TK, batch, D) → (seq_len, K, batch, D) → mean → (seq_len, batch, D)
|
| 195 |
+
decoded = v_out.view(seq_len, self.K, -1, self.D).mean(dim=1)
|
| 196 |
+
decoded = decoded.permute(1, 0, 2) # (batch, seq_len, D)
|
| 197 |
+
h = self.decode_proj(decoded) # (batch, seq_len, D)
|
| 198 |
+
h = self.norm(h) # (batch, seq_len, D)
|
| 199 |
+
return F.linear(h, self.embed_tokens.weight) # (batch, seq_len, vocab)
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
|
| 202 |
+
def generate(
|
| 203 |
+
self,
|
| 204 |
+
prompt_ids: torch.Tensor,
|
| 205 |
+
max_new_tokens: int,
|
| 206 |
+
temperature: float = 1.0,
|
| 207 |
+
top_k: int = 50,
|
| 208 |
+
eos_token_id: Optional[int] = None,
|
| 209 |
+
) -> torch.Tensor:
|
| 210 |
+
"""
|
| 211 |
+
自回归生成(SNN 神经元状态跨 token 连续维护)。
|
| 212 |
+
|
| 213 |
+
1. Prefill: forward_parallel 并行处理 prompt,建立所有神经元 V 状态
|
| 214 |
+
2. Autoregressive: 逐 token 生成,每 token 用 forward_parallel 处理 K 帧
|
| 215 |
+
复用 Triton parallel scan kernel,神经元 V 状态跨 token 连续传递
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
prompt_ids: (batch, prompt_len) token IDs
|
| 219 |
+
max_new_tokens: 最大生成 token 数
|
| 220 |
+
temperature: 采样温度(<=0 = greedy)
|
| 221 |
+
top_k: top-k 采样(None/0 = 不限制)
|
| 222 |
+
eos_token_id: 遇到此 token 停止生成
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
(batch, prompt_len + generated_len) 完整序列
|
| 226 |
+
"""
|
| 227 |
+
batch, prompt_len = prompt_ids.shape
|
| 228 |
+
|
| 229 |
+
# 重置所有神经元(新序列的初始条件 V=0)
|
| 230 |
+
for layer_module in self.layers:
|
| 231 |
+
functional.reset_net(layer_module)
|
| 232 |
+
functional.reset_net(self.output_neuron)
|
| 233 |
+
|
| 234 |
+
# ====== Prefill: parallel 处理整个 prompt ======
|
| 235 |
+
h_seq = self.encode(prompt_ids) # (prompt_len*K, batch, D), 连续值
|
| 236 |
+
h = h_seq
|
| 237 |
+
for layer_module in self.layers:
|
| 238 |
+
h, _ = layer_module.forward_parallel(h) # 推理忽略 ponder_cost
|
| 239 |
+
# 此时所有层的所有神经元 .v 状态 = prompt 末尾状态
|
| 240 |
+
|
| 241 |
+
logits = self.decode(h, prompt_len)
|
| 242 |
+
|
| 243 |
+
# 采样第一个新 token
|
| 244 |
+
next_token = self._sample(logits[:, -1, :], temperature, top_k)
|
| 245 |
+
generated = [next_token]
|
| 246 |
+
|
| 247 |
+
# ====== Autoregressive: 逐 token,forward_parallel 处理 K 帧 ======
|
| 248 |
+
for _ in range(max_new_tokens - 1):
|
| 249 |
+
if eos_token_id is not None and (next_token == eos_token_id).all():
|
| 250 |
+
break
|
| 251 |
+
|
| 252 |
+
# 编码单 token → K 帧连续值(复用 encode)
|
| 253 |
+
frames = self.encode(next_token) # (K, batch, D)
|
| 254 |
+
|
| 255 |
+
# K 帧通过 SNN — 不 reset,神经元 .v 跨 token 连续传递
|
| 256 |
+
h = frames
|
| 257 |
+
for layer_module in self.layers:
|
| 258 |
+
h, _ = layer_module.forward_parallel(h)
|
| 259 |
+
|
| 260 |
+
logits = self.decode(h, 1)
|
| 261 |
+
|
| 262 |
+
next_token = self._sample(logits[:, -1, :], temperature, top_k)
|
| 263 |
+
generated.append(next_token)
|
| 264 |
+
|
| 265 |
+
return torch.cat([prompt_ids, torch.cat(generated, dim=1)], dim=1)
|
| 266 |
+
|
| 267 |
+
def _sample(self, logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
|
| 268 |
+
"""从 logits 采样(temperature + top-k)。
|
| 269 |
+
|
| 270 |
+
Returns: (batch, 1)
|
| 271 |
+
"""
|
| 272 |
+
if temperature <= 0:
|
| 273 |
+
return logits.argmax(dim=-1, keepdim=True)
|
| 274 |
+
logits = logits / temperature
|
| 275 |
+
if top_k is not None and top_k > 0:
|
| 276 |
+
top_k = min(top_k, logits.size(-1))
|
| 277 |
+
v, _ = torch.topk(logits, top_k)
|
| 278 |
+
logits[logits < v[:, [-1]]] = float('-inf')
|
| 279 |
+
probs = F.softmax(logits, dim=-1)
|
| 280 |
+
return torch.multinomial(probs, num_samples=1)
|
| 281 |
+
|
| 282 |
+
def forward(
|
| 283 |
+
self,
|
| 284 |
+
token_ids: torch.Tensor,
|
| 285 |
+
target_ids: torch.Tensor = None,
|
| 286 |
+
) -> SNNModelOutput:
|
| 287 |
+
"""
|
| 288 |
+
前向传播(全膜电位 + 动态 K)。
|
| 289 |
+
|
| 290 |
+
encode → h_seq # 输入(embed repeat K 次,可微分)
|
| 291 |
+
snn_forward → h_out, pc # SNN 核心(全膜电位 + 动态 K 聚合)
|
| 292 |
+
decode → logits # 输出(V_post → K帧mean → proj → logits)
|
| 293 |
+
|
| 294 |
+
梯度流:
|
| 295 |
+
embed_tokens → repeat K → SNN layers(V_post + 动态K)
|
| 296 |
+
→ output_neuron(V_post) → K帧mean → decode_proj → logits(tied head)
|
| 297 |
+
ponder_cost: 动态 K 正则化,鼓励用更少步数处理简单 token
|
| 298 |
+
"""
|
| 299 |
+
batch, seq_len = token_ids.shape
|
| 300 |
+
|
| 301 |
+
# 重置所有神经元状态
|
| 302 |
+
for layer_module in self.layers:
|
| 303 |
+
functional.reset_net(layer_module)
|
| 304 |
+
functional.reset_net(self.output_neuron)
|
| 305 |
+
|
| 306 |
+
# 三段式
|
| 307 |
+
spike_seq = self.encode(token_ids) # 输入边界
|
| 308 |
+
h_out, ponder_cost = self.snn_forward(spike_seq) # SNN 核心 + ponder cost
|
| 309 |
+
logits = self.decode(h_out, seq_len) # 输出边界
|
| 310 |
+
|
| 311 |
+
if target_ids is not None:
|
| 312 |
+
logits_flat = logits.reshape(-1, self.vocab_size)
|
| 313 |
+
targets_flat = target_ids.reshape(-1)
|
| 314 |
+
self.last_loss = F.cross_entropy(
|
| 315 |
+
logits_flat, targets_flat,
|
| 316 |
+
ignore_index=0, reduction='none',
|
| 317 |
+
)
|
| 318 |
+
return SNNModelOutput(
|
| 319 |
+
last_loss=self.last_loss,
|
| 320 |
+
ponder_cost=ponder_cost,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
return SNNModelOutput(logits=logits, ponder_cost=ponder_cost)
|
| 324 |
+
|
| 325 |
+
def compensate_modulation_gradients(self, max_comp: float = 100.0):
|
| 326 |
+
"""
|
| 327 |
+
Natural Gradient 补偿(两阶段)。
|
| 328 |
+
|
| 329 |
+
Phase 1: Sigmoid/softplus 饱和补偿
|
| 330 |
+
β = sigmoid(b_beta), sigmoid 在高 β 区(β=0.99, sigmoid'=0.01)梯度衰减 100x。
|
| 331 |
+
补偿: grad /= activation'(b),等价于在 β/α 空间做梯度下降。
|
| 332 |
+
|
| 333 |
+
Phase 2: 层间梯度均衡
|
| 334 |
+
残差链反向传播每层放大 ~1.17×,20 层累积 ~20× L0/L19 比。
|
| 335 |
+
深层选择性参数(b_beta/b_alpha/b_th)梯度被压制,无法有效学习。
|
| 336 |
+
修复: 将每层调制参数梯度 norm 归一化到所有层的几何均值。
|
| 337 |
+
|
| 338 |
+
调用时机: scaler.unscale_(optimizer) 之后、clip_grad_norm_ 之前。
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
max_comp: 补偿因子上限(防止极端值导致不稳定)
|
| 342 |
+
"""
|
| 343 |
+
# ====== Phase 1: Sigmoid/softplus 饱和补偿 ======
|
| 344 |
+
for layer_module in self.layers:
|
| 345 |
+
block = layer_module.snn_block
|
| 346 |
+
|
| 347 |
+
# b_beta: sigmoid 饱和补偿
|
| 348 |
+
# sigmoid'(z) = sigmoid(z) · (1 - sigmoid(z)) = β · (1-β)
|
| 349 |
+
if block.b_beta.grad is not None:
|
| 350 |
+
with torch.no_grad():
|
| 351 |
+
beta = torch.sigmoid(block.b_beta.data)
|
| 352 |
+
sigmoid_deriv = (beta * (1.0 - beta)).clamp(min=1.0 / max_comp)
|
| 353 |
+
block.b_beta.grad.div_(sigmoid_deriv)
|
| 354 |
+
|
| 355 |
+
# b_alpha: softplus 补偿(较温和,softplus'(z) = sigmoid(z))
|
| 356 |
+
if block.b_alpha.grad is not None:
|
| 357 |
+
with torch.no_grad():
|
| 358 |
+
softplus_deriv = torch.sigmoid(block.b_alpha.data).clamp(min=0.1)
|
| 359 |
+
block.b_alpha.grad.div_(softplus_deriv)
|
| 360 |
+
|
| 361 |
+
# b_th: |·| 导数为 ±1,无衰减,不需要补偿
|
| 362 |
+
|
| 363 |
+
# ====== Phase 2: 层间梯度均衡 ======
|
| 364 |
+
# 残差链 h = h + sublayer(h) 的反向路径 ∂h_{l+1}/∂h_l = I + ∂sublayer/∂h_l
|
| 365 |
+
# 每层放大 ~1.17×, 20 层累积 ~20× → L0 梯度远大于 L19
|
| 366 |
+
# 用几何均值归一化每层调制参数梯度 norm,消除残差放大效应
|
| 367 |
+
with torch.no_grad():
|
| 368 |
+
for param_name in ['b_beta', 'b_alpha', 'b_th']:
|
| 369 |
+
norms = []
|
| 370 |
+
params_list = []
|
| 371 |
+
for layer_module in self.layers:
|
| 372 |
+
p = getattr(layer_module.snn_block, param_name)
|
| 373 |
+
if p.grad is not None:
|
| 374 |
+
n = p.grad.norm().item()
|
| 375 |
+
if n > 1e-12:
|
| 376 |
+
norms.append(n)
|
| 377 |
+
params_list.append(p)
|
| 378 |
+
|
| 379 |
+
if len(norms) >= 2:
|
| 380 |
+
# 几何均值: exp(mean(log(norms))) — 对数尺度均衡,不受极端值影响
|
| 381 |
+
log_mean = sum(math.log(n) for n in norms) / len(norms)
|
| 382 |
+
geo_mean = math.exp(log_mean)
|
| 383 |
+
for p, n in zip(params_list, norms):
|
| 384 |
+
scale = geo_mean / n
|
| 385 |
+
scale = max(min(scale, max_comp), 1.0 / max_comp)
|
| 386 |
+
p.grad.mul_(scale)
|
| 387 |
+
|
| 388 |
+
def get_param_groups(self) -> dict[str, list[nn.Parameter]]:
|
| 389 |
+
"""
|
| 390 |
+
按功能分组的可训练参数。
|
| 391 |
+
"""
|
| 392 |
+
groups = {
|
| 393 |
+
'embedding': [self.embed_tokens.weight],
|
| 394 |
+
'norm': [self.norm.gain],
|
| 395 |
+
'decode': list(self.decode_proj.parameters()),
|
| 396 |
+
# 输出神经元
|
| 397 |
+
'output_neuron': [self.output_neuron.w, self.output_neuron.v_th],
|
| 398 |
+
# RMSNorm(Pre-LN 分支归一化)
|
| 399 |
+
'rms_norms': [self.output_norm.weight],
|
| 400 |
+
# 残差流组件
|
| 401 |
+
'residual_projs': [],
|
| 402 |
+
'input_neurons': [],
|
| 403 |
+
# 动态 K: 停止投影
|
| 404 |
+
'halt_projs': [],
|
| 405 |
+
# SNNBlock 参数
|
| 406 |
+
'W_in': [],
|
| 407 |
+
'W_beta': [],
|
| 408 |
+
'W_alpha': [],
|
| 409 |
+
'W_th': [],
|
| 410 |
+
'W_gate': [],
|
| 411 |
+
'W_skip': [],
|
| 412 |
+
'W_out': [],
|
| 413 |
+
'b_beta': [],
|
| 414 |
+
'b_alpha': [],
|
| 415 |
+
'b_th': [],
|
| 416 |
+
'block_output_neuron': [],
|
| 417 |
+
# SNNFFN 参数
|
| 418 |
+
'ffn_gate_proj': [],
|
| 419 |
+
'ffn_up_proj': [],
|
| 420 |
+
'ffn_down_proj': [],
|
| 421 |
+
'ffn_skip_proj': [],
|
| 422 |
+
'ffn_neurons': [],
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
for layer_module in self.layers:
|
| 426 |
+
block = layer_module.snn_block
|
| 427 |
+
ffn = layer_module.snn_ffn
|
| 428 |
+
|
| 429 |
+
# 残差流组件
|
| 430 |
+
groups['residual_projs'].extend([
|
| 431 |
+
layer_module.block_out_proj.weight,
|
| 432 |
+
layer_module.ffn_out_proj.weight,
|
| 433 |
+
])
|
| 434 |
+
groups['input_neurons'].extend([
|
| 435 |
+
layer_module.input_neuron1.w,
|
| 436 |
+
layer_module.input_neuron1.v_th,
|
| 437 |
+
layer_module.input_neuron2.w,
|
| 438 |
+
layer_module.input_neuron2.v_th,
|
| 439 |
+
])
|
| 440 |
+
groups['rms_norms'].extend([
|
| 441 |
+
layer_module.block_norm.weight,
|
| 442 |
+
layer_module.ffn_norm.weight,
|
| 443 |
+
])
|
| 444 |
+
|
| 445 |
+
# 动态 K: 停止投影参数
|
| 446 |
+
groups['halt_projs'].extend(list(layer_module.block_halt.parameters()))
|
| 447 |
+
groups['halt_projs'].extend(list(layer_module.ffn_halt.parameters()))
|
| 448 |
+
|
| 449 |
+
# SNNBlock 参数
|
| 450 |
+
groups['W_in'].append(block.W_in.weight)
|
| 451 |
+
groups['W_beta'].extend([block.W_beta_x.weight])
|
| 452 |
+
groups['W_alpha'].extend([block.W_alpha_x.weight])
|
| 453 |
+
groups['W_th'].extend([block.W_th_x.weight])
|
| 454 |
+
groups['W_gate'].append(block.W_gate.weight)
|
| 455 |
+
groups['W_skip'].append(block.W_skip.weight)
|
| 456 |
+
groups['W_out'].append(block.W_out.weight)
|
| 457 |
+
groups['b_beta'].append(block.b_beta)
|
| 458 |
+
groups['b_alpha'].append(block.b_alpha)
|
| 459 |
+
groups['b_th'].append(block.b_th)
|
| 460 |
+
|
| 461 |
+
# SNNFFN 参数
|
| 462 |
+
groups['ffn_gate_proj'].append(ffn.gate_proj.weight)
|
| 463 |
+
groups['ffn_up_proj'].append(ffn.up_proj.weight)
|
| 464 |
+
groups['ffn_down_proj'].append(ffn.down_proj.weight)
|
| 465 |
+
groups['ffn_skip_proj'].append(ffn.skip_proj.weight)
|
| 466 |
+
groups['ffn_neurons'].extend([
|
| 467 |
+
ffn.gate_neuron.w, ffn.gate_neuron.v_th,
|
| 468 |
+
ffn.up_neuron.w, ffn.up_neuron.v_th,
|
| 469 |
+
])
|
| 470 |
+
|
| 471 |
+
return groups
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58fac5ad9b11e95d7601686846399b33f84593d29903d68c3dc38c05e4b93d16
|
| 3 |
+
size 3496634368
|
modeling_neuronspark.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NeuronSpark: SNN 隐状态空间语言模型 — HuggingFace 接口
|
| 3 |
+
|
| 4 |
+
用法:
|
| 5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 6 |
+
|
| 7 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 8 |
+
"checkpoints_sft/", trust_remote_code=True,
|
| 9 |
+
)
|
| 10 |
+
tokenizer = AutoTokenizer.from_pretrained("checkpoints_sft/")
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import PreTrainedModel, GenerationMixin
|
| 17 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 18 |
+
|
| 19 |
+
from configuration_neuronspark import NeuronSparkConfig
|
| 20 |
+
from model import SNNLanguageModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class NeuronSparkForCausalLM(PreTrainedModel, GenerationMixin):
|
| 24 |
+
"""
|
| 25 |
+
SNN 语言模型 — CausalLM 接口。
|
| 26 |
+
|
| 27 |
+
封装 SNNLanguageModel,提供 HuggingFace 标准接口:
|
| 28 |
+
- forward(input_ids, labels) → CausalLMOutputWithPast
|
| 29 |
+
- generate() 支持(通过 GenerationMixin)
|
| 30 |
+
"""
|
| 31 |
+
config_class = NeuronSparkConfig
|
| 32 |
+
supports_gradient_checkpointing = True
|
| 33 |
+
|
| 34 |
+
def __init__(self, config: NeuronSparkConfig):
|
| 35 |
+
super().__init__(config)
|
| 36 |
+
self.model = SNNLanguageModel(
|
| 37 |
+
vocab_size=config.vocab_size,
|
| 38 |
+
D=config.D,
|
| 39 |
+
N=config.N,
|
| 40 |
+
K=config.K,
|
| 41 |
+
num_layers=config.num_layers,
|
| 42 |
+
D_ff=config.D_ff,
|
| 43 |
+
v_th_min=config.v_th_min,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def get_input_embeddings(self):
|
| 47 |
+
return self.model.embed_tokens
|
| 48 |
+
|
| 49 |
+
def set_input_embeddings(self, value):
|
| 50 |
+
self.model.embed_tokens = value
|
| 51 |
+
|
| 52 |
+
def get_output_embeddings(self):
|
| 53 |
+
# tied head: 输出复用 embed_tokens.weight
|
| 54 |
+
return self.model.embed_tokens
|
| 55 |
+
|
| 56 |
+
def forward(
|
| 57 |
+
self,
|
| 58 |
+
input_ids: torch.Tensor,
|
| 59 |
+
labels: Optional[torch.Tensor] = None,
|
| 60 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 61 |
+
**kwargs,
|
| 62 |
+
) -> CausalLMOutputWithPast:
|
| 63 |
+
"""
|
| 64 |
+
前向传播。
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
input_ids: (batch, seq_len) token IDs
|
| 68 |
+
labels: (batch, seq_len) 目标 token IDs(可选,用于计算 loss)
|
| 69 |
+
attention_mask: 兼容参数(SNN 无 attention,忽略)
|
| 70 |
+
"""
|
| 71 |
+
if labels is not None:
|
| 72 |
+
out = self.model(input_ids, target_ids=labels)
|
| 73 |
+
# 计算 masked loss
|
| 74 |
+
loss_mask = (labels != 0).float().view(-1)
|
| 75 |
+
loss = (out.last_loss * loss_mask).sum() / loss_mask.sum()
|
| 76 |
+
# 加 ponder cost
|
| 77 |
+
if out.ponder_cost is not None:
|
| 78 |
+
loss = loss + 0.01 * out.ponder_cost
|
| 79 |
+
return CausalLMOutputWithPast(loss=loss)
|
| 80 |
+
else:
|
| 81 |
+
out = self.model(input_ids)
|
| 82 |
+
return CausalLMOutputWithPast(logits=out.logits)
|
| 83 |
+
|
| 84 |
+
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
| 85 |
+
"""generate() 所需的输入准备。"""
|
| 86 |
+
return {"input_ids": input_ids}
|
| 87 |
+
|
| 88 |
+
@torch.no_grad()
|
| 89 |
+
def generate(
|
| 90 |
+
self,
|
| 91 |
+
input_ids: torch.Tensor,
|
| 92 |
+
max_new_tokens: int = 256,
|
| 93 |
+
temperature: float = 1.0,
|
| 94 |
+
top_k: int = 50,
|
| 95 |
+
eos_token_id: Optional[int] = None,
|
| 96 |
+
**kwargs,
|
| 97 |
+
) -> torch.Tensor:
|
| 98 |
+
"""
|
| 99 |
+
自回归生成(直接调用 SNN 的 generate 方法)。
|
| 100 |
+
"""
|
| 101 |
+
return self.model.generate(
|
| 102 |
+
prompt_ids=input_ids,
|
| 103 |
+
max_new_tokens=max_new_tokens,
|
| 104 |
+
temperature=temperature,
|
| 105 |
+
top_k=top_k,
|
| 106 |
+
eos_token_id=eos_token_id,
|
| 107 |
+
)
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<|im_start|>",
|
| 3 |
+
"eos_token": "<|im_end|>",
|
| 4 |
+
"unk_token": "<unk>",
|
| 5 |
+
"pad_token": "<|im_end|>",
|
| 6 |
+
"additional_special_tokens": [
|
| 7 |
+
"<s>",
|
| 8 |
+
"</s>"
|
| 9 |
+
]
|
| 10 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": false,
|
| 5 |
+
"bos_token": "<|im_start|>",
|
| 6 |
+
"eos_token": "<|im_end|>",
|
| 7 |
+
"pad_token": "<|im_end|>",
|
| 8 |
+
"unk_token": "<unk>",
|
| 9 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 10 |
+
"clean_up_tokenization_spaces": false,
|
| 11 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 12 |
+
"chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
| 13 |
+
}
|