Initial release: NeuronSpark-0.9B pretrained SNN language model
Browse files- LICENSE +190 -0
- README.md +117 -0
- atomic_ops/__init__.py +9 -0
- atomic_ops/fp16_codec.py +98 -0
- atomic_ops/lateral_inhibition.py +215 -0
- atomic_ops/parallel_scan.py +829 -0
- atomic_ops/plif_node.py +81 -0
- atomic_ops/rms_norm.py +36 -0
- atomic_ops/selective_plif.py +94 -0
- atomic_ops/snn_block.py +242 -0
- atomic_ops/snn_decoder_layer.py +327 -0
- atomic_ops/snn_ffn.py +185 -0
- config.json +21 -0
- configuration.json +1 -0
- configuration_neuronspark.py +38 -0
- model.py +471 -0
- model.safetensors +3 -0
- modeling_neuronspark.py +107 -0
- special_tokens_map.json +10 -0
- tokenizer.json +0 -0
- tokenizer_config.json +13 -0
LICENSE
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to the Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by the Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding any notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
Copyright 2025 Zhengzheng Tang
|
| 179 |
+
|
| 180 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 181 |
+
you may not use this file except in compliance with the License.
|
| 182 |
+
You may obtain a copy of the License at
|
| 183 |
+
|
| 184 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 185 |
+
|
| 186 |
+
Unless required by applicable law or agreed to in writing, software
|
| 187 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 188 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 189 |
+
See the License for the specific language governing permissions and
|
| 190 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- zh
|
| 5 |
+
library_name: transformers
|
| 6 |
+
tags:
|
| 7 |
+
- snn
|
| 8 |
+
- spiking-neural-network
|
| 9 |
+
- text-generation
|
| 10 |
+
- neuromorphic
|
| 11 |
+
pipeline_tag: text-generation
|
| 12 |
+
---
|
| 13 |
+
# NeuronSpark-0.9B
|
| 14 |
+
|
| 15 |
+
## Introduction
|
| 16 |
+
|
| 17 |
+
**NeuronSpark-0.9B** is a **0.87-billion parameter language model built entirely on Spiking Neural Networks (SNNs)**. Unlike conventional Transformer-based LLMs that rely on attention mechanisms, NeuronSpark replaces the entire computation backbone with biologically-inspired spiking neurons, achieving language modeling through membrane potential dynamics, surrogate gradient training, and adaptive computation (PonderNet).
|
| 18 |
+
|
| 19 |
+
This is the **pretrained base model** (85,000 steps on a small subset of Seq-Monkey corpus).
|
| 20 |
+
|
| 21 |
+
> **Note on training data**: Due to limited compute resources (single DGX Spark), this model was trained on only **~85K steps with a small fraction of the full Seq-Monkey 10B-token corpus**. Despite the minimal training data, the model demonstrates emergent language capabilities — validating the architectural viability of pure SNN language models. We plan to continue scaling with more data and compute in future work.
|
| 22 |
+
|
| 23 |
+
For the instruction-tuned chat version, see [NeuronSpark-0.9B-Chat](https://modelscope.cn/models/Brain2nd/NeuronSpark-0.9B-Chat).
|
| 24 |
+
|
| 25 |
+
## Model Details
|
| 26 |
+
|
| 27 |
+
| Attribute | Value |
|
| 28 |
+
|-----------|-------|
|
| 29 |
+
| Parameters | 874M |
|
| 30 |
+
| Architecture | SNN Hidden State Space Model |
|
| 31 |
+
| Hidden Dimension (D) | 896 |
|
| 32 |
+
| Layers | 20 |
|
| 33 |
+
| SNN Timesteps (K) | 16 (PonderNet adaptive) |
|
| 34 |
+
| State Expansion (N) | 8 |
|
| 35 |
+
| FFN Dimension | 2688 |
|
| 36 |
+
| Vocabulary | 6144 (custom BPE) |
|
| 37 |
+
| Context Length | 512 tokens |
|
| 38 |
+
| Training Data | Seq-Monkey (small subset, Chinese) |
|
| 39 |
+
| Training Tokens | ~1.4B (of ~10B available) |
|
| 40 |
+
| Precision | bfloat16 |
|
| 41 |
+
| License | Apache 2.0 |
|
| 42 |
+
|
| 43 |
+
## Architecture Highlights
|
| 44 |
+
|
| 45 |
+
- **Pure SNN**: No attention, no standard MLP — all computation via PLIF (Parametric Leaky Integrate-and-Fire) neurons
|
| 46 |
+
- **Membrane Potential Leakage Activation**: PLIFNode outputs `(1-β)·V_post` (leak current), naturally emphasizing fast-responding neurons over slow-memory neurons
|
| 47 |
+
- **Selective State Space**: Hidden neurons with input-dependent dynamic β(t), α(t), V_th(t) — analogous to selective state space models (Mamba)
|
| 48 |
+
- **PonderNet Adaptive K**: Each token dynamically decides how many SNN timesteps to use (1~K), with geometric distribution weighting
|
| 49 |
+
- **Triton Fused Kernels**: Custom PLIF forward/backward kernels, single-pass sequential scan replacing 3-phase approach
|
| 50 |
+
- **Pre-LN Residual Stream**: Continuous residual flow with RMSNorm, matching Qwen3/LLaMA architecture pattern
|
| 51 |
+
|
| 52 |
+
## Quickstart
|
| 53 |
+
|
| 54 |
+
```python
|
| 55 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 56 |
+
|
| 57 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 58 |
+
"Brain2nd/NeuronSpark-0.9B",
|
| 59 |
+
trust_remote_code=True,
|
| 60 |
+
)
|
| 61 |
+
tokenizer = AutoTokenizer.from_pretrained("Brain2nd/NeuronSpark-0.9B")
|
| 62 |
+
|
| 63 |
+
# Text completion
|
| 64 |
+
text = f"{tokenizer.bos_token}人工智能的发展"
|
| 65 |
+
input_ids = tokenizer(text, return_tensors="pt")["input_ids"]
|
| 66 |
+
|
| 67 |
+
output_ids = model.generate(
|
| 68 |
+
input_ids,
|
| 69 |
+
max_new_tokens=128,
|
| 70 |
+
temperature=0.8,
|
| 71 |
+
top_k=50,
|
| 72 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 73 |
+
)
|
| 74 |
+
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
**Example Output:**
|
| 78 |
+
```
|
| 79 |
+
人工智能的发展,为人类的未来发展提供了新的机遇。在未来,人工智能将是未来人工智能发展的重要方向。
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## Requirements
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
pip install torch transformers spikingjelly safetensors
|
| 86 |
+
# For Triton kernels (GPU): pip install triton
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
## Training
|
| 90 |
+
|
| 91 |
+
Trained on a single NVIDIA DGX Spark (GB10, 128GB unified memory) with 4-GPU DDP.
|
| 92 |
+
Due to compute constraints, training used only a small subset of the full corpus (~85K steps, ~1.4B tokens of ~10B available). Even with this limited data budget, the model acquires basic language generation ability, demonstrating the architectural viability of pure SNN language modeling.
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
torchrun --nproc_per_node=4 train_ddp.py \
|
| 96 |
+
--D 896 --D_ff 2688 --K 16 --num_layers 20 \
|
| 97 |
+
--batch_size 8 --accumulation_steps 8 \
|
| 98 |
+
--learning_rate 2e-4 --warmup_iters 1000
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Citation
|
| 102 |
+
|
| 103 |
+
```bibtex
|
| 104 |
+
@misc{neuronspark2025,
|
| 105 |
+
title={NeuronSpark: A Spiking Neural Network Language Model with Selective State Space Dynamics},
|
| 106 |
+
author={Zhengzheng Tang},
|
| 107 |
+
year={2025},
|
| 108 |
+
url={https://github.com/Brain2nd/NeuronSpark}
|
| 109 |
+
}
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
## Contact
|
| 113 |
+
|
| 114 |
+
- **Author**: Zhengzheng Tang
|
| 115 |
+
- **Email**: zztangbu@bu.edu
|
| 116 |
+
- **GitHub**: [Brain2nd/NeuronSpark](https://github.com/Brain2nd/NeuronSpark)
|
| 117 |
+
|
atomic_ops/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .selective_plif import SelectivePLIFNode
|
| 2 |
+
from .plif_node import PLIFNode
|
| 3 |
+
from .lateral_inhibition import LateralInhibition
|
| 4 |
+
from .snn_block import SNNBlock
|
| 5 |
+
from .snn_ffn import SNNFFN
|
| 6 |
+
from .snn_decoder_layer import SNNDecoderLayer
|
| 7 |
+
from .parallel_scan import hillis_steele_scan, linear_recurrence, plif_parallel_forward
|
| 8 |
+
from .fp16_codec import fp16_encode, fp16_decode
|
| 9 |
+
from .rms_norm import RMSNorm
|
atomic_ops/fp16_codec.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FP16 二进制编码/解码 — 模型边界操作(无可训练参数)。
|
| 3 |
+
|
| 4 |
+
IEEE 754 float16 位布局(K=16 时间步):
|
| 5 |
+
时间步: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
|
| 6 |
+
位: sign E4 E3 E2 E1 E0 M9 M8 M7 M6 M5 M4 M3 M2 M1 M0
|
| 7 |
+
含义: 符号 ←── 指数(bias=15) ──→ ←────────── 尾数(隐含 1.xxx) ──────────→
|
| 8 |
+
|
| 9 |
+
编码: embedding → IEEE 754 float16 位提取 → 16 帧二值 spike(detach,固定预处理)
|
| 10 |
+
解码: 16 帧二值 spike → IEEE 754 位重建 → 连续值(可微分,梯度通过 surrogate grad 传播)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def fp16_encode(emb: Tensor, K: int = 16) -> Tensor:
    """FP16 binary encoding (model-boundary op, fixed preprocessing).

    Converts a continuous embedding into its IEEE 754 float16 bit pattern,
    emitted as 16 binary spike frames per token (MSB first: sign, exponent,
    mantissa), to be used as the SNN's spike input.

    Args:
        emb: (batch, seq_len, D) continuous embedding.
        K: number of timesteps; must be 16 (one timestep per float16 bit).

    Returns:
        spike_seq: (seq_len*K, batch, D) binary {0, 1}, detached.

    Raises:
        ValueError: if K != 16 (the bit-extraction below is fixed to the
            16-bit float16 layout; any other K would fail obscurely).
    """
    if K != 16:
        raise ValueError(
            f"fp16_encode requires K=16 (one timestep per float16 bit), got K={K}"
        )
    batch, seq_len, D = emb.shape

    # Cast to float16 to obtain the IEEE 754 bit pattern.
    # Clamp first so overflow cannot produce Inf (float16 max is 65504).
    emb_fp16 = emb.float().clamp(-65504.0, 65504.0).half()
    bits_int = emb_fp16.view(torch.int16)  # (batch, seq_len, D) raw bit pattern

    # Extract the 16 bits, MSB first (sign, exponent, mantissa).
    shifts = torch.arange(15, -1, -1, device=emb.device)  # [15, 14, ..., 0]
    # bits_int: (batch, seq_len, D) → unsqueeze → (batch, seq_len, 1, D)
    # shifts:   (K,) → view → (1, 1, K, 1)
    # Arithmetic right shift on negative values is harmless: `& 1` keeps
    # only the bit of interest.
    bits = (bits_int.unsqueeze(2) >> shifts.view(1, 1, K, 1)) & 1  # (batch, seq_len, K, D)

    # Convert to the compute dtype and detach (encoding carries no gradient).
    bits = bits.to(emb.dtype).detach()

    # (batch, seq_len, K, D) → (seq_len*K, batch, D)
    return bits.reshape(batch, seq_len * K, D).permute(1, 0, 2).contiguous()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def fp16_decode(spikes: Tensor, seq_len: int, K: int = 16) -> Tensor:
    """Exact FP16 bit decoding: rebuild float16 values from 16 binary spikes.

    Exact inverse of fp16_encode. Fully differentiable — gradients flow
    through the IEEE 754 reconstruction formula into every spike output,
    and from there into the SNN via the surrogate gradient.

    IEEE 754 float16 reconstruction:
        Normal    (exp > 0): (-1)^sign * 2^(exp - 15) * (1 + mant_frac)
        Subnormal (exp = 0): (-1)^sign * 2^(-14) * mant_frac
    where mant_frac = sum(mant_bit_i * 2^{-(i+1)}), i = 0..9.

    Args:
        spikes: (seq_len*K, batch, D) binary {0, 1} output-neuron spikes.
        seq_len: token sequence length.
        K: number of timesteps (= 16).

    Returns:
        decoded: (batch, seq_len, D) continuous values.
    """
    batch, D = spikes.shape[1], spikes.shape[2]

    # (seq_len*K, batch, D) → (batch, seq_len, K, D)
    frames = spikes.permute(1, 0, 2).reshape(batch, seq_len, K, D)

    sign_bit = frames[:, :, 0, :]    # bit 0
    exp_bits = frames[:, :, 1:6, :]  # bits 1-5
    mant_bits = frames[:, :, 6:, :]  # bits 6-15

    # Sign: bit 0 → +1 (0) or -1 (1)
    sign = 1.0 - 2.0 * sign_bit

    # Exponent: weighted sum of the 5 exponent bits → integer 0..31
    exp_w = torch.tensor(
        [16.0, 8.0, 4.0, 2.0, 1.0],
        device=spikes.device, dtype=spikes.dtype,
    )
    exp_val = (exp_bits * exp_w.view(1, 1, 5, 1)).sum(dim=2)

    # Mantissa fraction: weighted sum of the 10 mantissa bits → [0, 1)
    mant_w = torch.tensor(
        [2.0 ** (-i) for i in range(1, 11)],
        device=spikes.device, dtype=spikes.dtype,
    )
    mant_frac = (mant_bits * mant_w.view(1, 1, 10, 1)).sum(dim=2)

    # IEEE 754 reconstruction: pick normal vs subnormal per element.
    normal_val = sign * torch.exp2(exp_val - 15.0) * (1.0 + mant_frac)
    subnormal_val = sign * (2.0 ** -14) * mant_frac
    return torch.where(exp_val > 0, normal_val, subnormal_val)
|
atomic_ops/lateral_inhibition.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LateralInhibition: 侧抑制归一化(Divisive Normalization)
|
| 3 |
+
|
| 4 |
+
神经科学基础:
|
| 5 |
+
Carandini & Heeger (2012) "Normalization as a canonical neural computation"
|
| 6 |
+
侧抑制是大脑中最基本的计算原语之一:兴奋性神经元的活动通过抑制性中间神经元池
|
| 7 |
+
反馈调节,实现增益控制(gain control)。
|
| 8 |
+
|
| 9 |
+
SNN 机制:
|
| 10 |
+
1. 兴奋性群体活动: activity_i = h_i²
|
| 11 |
+
2. 抑制性中间神经元池: pool = mean(activity) = mean(h²)
|
| 12 |
+
3. 分裂抑制 (shunting inhibition): h_norm = h / sqrt(pool + ε)
|
| 13 |
+
4. 增益调制 (gain modulation): output = gain · h_norm
|
| 14 |
+
|
| 15 |
+
替换 RMSNorm:数学操作等价,但在 SNN 框架中有明确的神经科学对应——
|
| 16 |
+
RMSNorm 是 divisive normalization 的特例。
|
| 17 |
+
|
| 18 |
+
Triton fused kernel:
|
| 19 |
+
- 前向: {mean(h²), rsqrt, element-wise mul} → 1 kernel launch
|
| 20 |
+
- 反向: {recompute norm, grad_gain, grad_h} → 1 kernel launch
|
| 21 |
+
- 每行 (D dim) 一个 block,行间并行
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from spikingjelly.activation_based import base
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ============================================================
|
| 32 |
+
# Triton fused kernels
|
| 33 |
+
# ============================================================
|
| 34 |
+
|
| 35 |
+
# Prefer the system CUDA ptxas over the one bundled with Triton, unless the
# user has already pinned a path via TRITON_PTXAS_PATH.
_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
    os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS

# Optional Triton dependency: when unavailable, the module falls back to a
# pure-PyTorch path (the fused kernels below are only defined when this is True).
_HAS_TRITON = False
try:
    import triton
    import triton.language as tl
    _HAS_TRITON = True
except ImportError:
    pass
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if _HAS_TRITON:
|
| 49 |
+
|
| 50 |
+
@triton.jit
def _li_fwd_kernel(
    X_ptr, GAIN_ptr, OUT_ptr,
    stride_row,
    D: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Forward: out = x * rsqrt(mean(x²) + eps) * gain

    Grid: (num_rows,). Each program processes one row of D elements.
    Computation in float32; storage in input dtype.
    BLOCK_D must be a power of 2 >= D (caller passes next_power_of_2(D)).
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_D)
    mask = cols < D
    off = row * stride_row + cols

    # Load in float32; masked lanes load 0.0 so they contribute nothing
    # to the reduction below, while the mean still divides by the true D.
    x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
    gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # Inhibitory pool: population activity (mean of squared responses)
    variance = tl.sum(x * x, axis=0) / D
    rrms = 1.0 / tl.sqrt(variance + eps)

    # Divisive inhibition + gain modulation
    out = x * rrms * gain

    # Store casts back to OUT's element dtype; masked lanes untouched.
    tl.store(OUT_ptr + off, out, mask=mask)
|
| 80 |
+
|
| 81 |
+
@triton.jit
def _li_bwd_kernel(
    DOUT_ptr, X_ptr, GAIN_ptr,
    DX_ptr, DGAIN_ptr,
    stride_row,
    D: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward: grad_x, grad_gain (per-row, reduced externally).

    Grid: (num_rows,).
    d_x = rrms * (d_out * gain - x_hat * mean(d_out * gain * x_hat))
    d_gain_row = d_out * x_hat (sum across rows done outside kernel)

    DGAIN is written at full (num_rows, D) resolution; the host reduces
    it to a single (D,) gradient with a sum over rows.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_D)
    mask = cols < D
    off = row * stride_row + cols

    dout = tl.load(DOUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
    x = tl.load(X_ptr + off, mask=mask, other=0.0).to(tl.float32)
    gain = tl.load(GAIN_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    # Recompute forward (avoid saving intermediate tensors); must match
    # _li_fwd_kernel exactly for the gradient to be consistent.
    variance = tl.sum(x * x, axis=0) / D
    rrms = 1.0 / tl.sqrt(variance + eps)
    x_hat = x * rrms

    # grad_gain (per-row contribution)
    dgain = dout * x_hat
    tl.store(DGAIN_ptr + off, dgain, mask=mask)

    # grad_x: rrms * (dout*gain - x_hat * mean(dout*gain*x_hat))
    # (standard RMS-norm backward: projects out the component along x_hat)
    dout_gain = dout * gain
    dot = tl.sum(dout_gain * x_hat, axis=0) / D
    dx = (dout_gain - x_hat * dot) * rrms

    tl.store(DX_ptr + off, dx, mask=mask)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class _LateralInhibitionTriton(torch.autograd.Function):
    """Triton-accelerated lateral inhibition (divisive normalization).

    Forward and backward each launch a single fused kernel; the per-row
    gain gradients are reduced to one (D,) vector on the PyTorch side.
    """

    @staticmethod
    def forward(ctx, x, gain, eps):
        input_shape = x.shape
        feat_dim = input_shape[-1]

        # Collapse all leading dims into rows so each kernel program
        # normalizes one contiguous row of feat_dim elements.
        flat = x.reshape(-1, feat_dim).contiguous()
        num_rows = flat.shape[0]
        result = torch.empty_like(flat)

        block = triton.next_power_of_2(feat_dim)
        _li_fwd_kernel[(num_rows,)](
            flat, gain, result,
            flat.stride(0),
            D=feat_dim, eps=eps, BLOCK_D=block,
        )

        # The backward kernel recomputes the norm, so only inputs are saved.
        ctx.save_for_backward(flat, gain)
        ctx.eps = eps
        ctx.orig_shape = input_shape
        ctx.N = num_rows
        ctx.D = feat_dim
        return result.reshape(input_shape)

    @staticmethod
    def backward(ctx, grad_output):
        flat, gain = ctx.saved_tensors
        num_rows, feat_dim = ctx.N, ctx.D

        grad_flat = grad_output.reshape(num_rows, feat_dim).contiguous()
        dx = torch.empty_like(flat)
        per_row_dgain = torch.empty_like(flat)

        block = triton.next_power_of_2(feat_dim)
        _li_bwd_kernel[(num_rows,)](
            grad_flat, flat, gain,
            dx, per_row_dgain,
            flat.stride(0),
            D=feat_dim, eps=ctx.eps, BLOCK_D=block,
        )

        # Collapse per-row contributions into the shared (D,) gain gradient.
        dgain = per_row_dgain.sum(dim=0)

        # Third return slot is for eps, which gets no gradient.
        return dx.reshape(ctx.orig_shape), dgain, None
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ============================================================
|
| 175 |
+
# Public module
|
| 176 |
+
# ============================================================
|
| 177 |
+
|
| 178 |
+
class LateralInhibition(base.MemoryModule):
    """Lateral-inhibition normalization layer (divisive normalization).

    Implements gain control via an inhibitory interneuron pool:

        pool    = mean(h**2, dim=-1)        # population activity level
        h_norm  = h / sqrt(pool + eps)      # shunting (divisive) inhibition
        output  = gain * h_norm             # gain modulation

    Mathematically equivalent to RMSNorm, but in the SNN framework it
    corresponds to divisive normalization (Carandini & Heeger, 2012), one of
    the most fundamental computational primitives in neuroscience.

    CUDA: fused Triton kernels (one launch each for forward and backward).
    CPU:  plain PyTorch fallback.

    Args:
        dim: feature dimension D.
        eps: numerical-stability constant added under the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Per-feature multiplicative gain, initialized to identity.
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        if _HAS_TRITON and h.is_cuda:
            return _LateralInhibitionTriton.apply(h, self.gain, self.eps)
        # CPU fallback: identical math in plain PyTorch.
        inv_rms = torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gain * (h * inv_rms)

    def extra_repr(self):
        return f'dim={self.dim}, eps={self.eps}'
|
atomic_ops/parallel_scan.py
ADDED
|
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Parallel Scan 工具函数:SNN 线性递推的高效并行求解
|
| 3 |
+
|
| 4 |
+
实现三层后端:
|
| 5 |
+
1. Fused PLIF kernel(默认,CUDA + Sigmoid surrogate):
|
| 6 |
+
单 kernel 完成 PLIF 前向(scan + spike + soft reset)和反向(surrogate gradient)
|
| 7 |
+
· per-element beta/v_th: _fused_plif_fwd_kernel / _fused_plif_bwd_kernel
|
| 8 |
+
· row-param beta/v_th: _fused_plif_fwd_rowparam_kernel / _fused_plif_bwd_rowparam_kernel
|
| 9 |
+
2. Triton linear_recurrence(CUDA,非 Sigmoid 或无 surrogate):
|
| 10 |
+
列级并行 scan,O(K) 工作量,1 次 kernel launch
|
| 11 |
+
3. Hillis-Steele parallel scan(CPU 回退):O(K log K) 工作量
|
| 12 |
+
|
| 13 |
+
线性递推:
|
| 14 |
+
V[k] = a[k] * V[k-1] + b[k], V[-1] = v_init
|
| 15 |
+
|
| 16 |
+
PLIF 神经元动力学:
|
| 17 |
+
V_pre[k] = beta[k] * V_post[k-1] + u[k]
|
| 18 |
+
s[k] = Θ(V_pre[k] - v_th[k])
|
| 19 |
+
V_post[k] = V_pre[k] - v_th[k] * s[k]
|
| 20 |
+
|
| 21 |
+
数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ============================================================
|
| 29 |
+
# Triton fused recurrence kernels
|
| 30 |
+
# ============================================================
|
| 31 |
+
|
| 32 |
+
# DGX Spark (GB10, sm_121a): the ptxas bundled with Triton 3.5.1 does not
# support sm_121a, so point Triton at the ptxas from the system CUDA 13.0
# toolkit instead.
_SYSTEM_PTXAS = '/usr/local/cuda-13.0/bin/ptxas'
# Only override when the system ptxas exists and the user has not already
# chosen one explicitly.
if os.path.exists(_SYSTEM_PTXAS) and 'TRITON_PTXAS_PATH' not in os.environ:
    os.environ['TRITON_PTXAS_PATH'] = _SYSTEM_PTXAS

# Triton is optional: the fused CUDA paths below are used when it imports
# cleanly, otherwise the pure-PyTorch fallbacks take over.
_HAS_TRITON = False
try:
    import triton
    import triton.language as tl
    _HAS_TRITON = True
except ImportError:
    pass
|
| 45 |
+
|
| 46 |
+
if _HAS_TRITON:
|
| 47 |
+
|
| 48 |
+
    @triton.jit
    def _fwd_recurrence_kernel(
        A_ptr, B_ptr, INIT_ptr, OUT_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Forward: V[k] = A[k]*V[k-1] + B[k], V[-1] = init.

        Grid: (ceil(num_cols / BLOCK),)
        Each program processes BLOCK columns across all K sequential steps.
        Accumulation in float32; storage in input dtype.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols  # guard the ragged final block

        # Running state V[k-1], kept in fp32 registers for accuracy.
        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        for k in range(K):
            off = k * num_cols + cols  # row-major (K, num_cols) layout
            a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
            b = tl.load(B_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v = a * v + b
            tl.store(OUT_ptr + off, v, mask=mask)
|
| 72 |
+
|
| 73 |
+
    @triton.jit
    def _bwd_recurrence_kernel(
        A_ptr, V_ptr, INIT_ptr, GRAD_OUT_ptr,
        GRAD_A_ptr, GRAD_B_ptr, GRAD_INIT_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Backward for V[k] = A[k]*V[k-1] + B[k].

        Reverse accumulation (k from K-1 down to 0):
            g = 0
            for k = K-1, ..., 0:
                g += grad_out[k]
                grad_B[k] = g
                grad_A[k] = g * V[k-1]   (V[-1] = init)
                g = g * A[k]
            grad_init = g
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        # Running adjoint dL/dV[k], accumulated in fp32.
        g = tl.zeros([BLOCK], dtype=tl.float32)

        for k_rev in range(K):
            k = K - 1 - k_rev  # iterate time steps in reverse
            off = k * num_cols + cols

            dV = tl.load(GRAD_OUT_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g = g + dV

            # dV[k]/dB[k] = 1, so grad_B[k] is the accumulated adjoint.
            tl.store(GRAD_B_ptr + off, g, mask=mask)

            # dV[k]/dA[k] = V[k-1]; the k == 0 step reads the initial state.
            if k > 0:
                v_prev = tl.load(
                    V_ptr + (k - 1) * num_cols + cols,
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
            tl.store(GRAD_A_ptr + off, g * v_prev, mask=mask)

            # Propagate the adjoint one step back: dV[k]/dV[k-1] = A[k].
            a = tl.load(A_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g = g * a

        tl.store(GRAD_INIT_ptr + cols, g, mask=mask)
|
| 119 |
+
|
| 120 |
+
class _TritonLinearRecurrence(torch.autograd.Function):
|
| 121 |
+
"""Fused Triton linear recurrence: V[k] = A[k]*V[k-1] + B[k]."""
|
| 122 |
+
|
| 123 |
+
_BLOCK = 128
|
| 124 |
+
|
| 125 |
+
@staticmethod
|
| 126 |
+
def forward(ctx, beta, u, v_init):
|
| 127 |
+
beta_c = beta.contiguous()
|
| 128 |
+
u_c = u.contiguous()
|
| 129 |
+
v_init_c = v_init.contiguous()
|
| 130 |
+
|
| 131 |
+
K = beta_c.shape[0]
|
| 132 |
+
num_cols = beta_c[0].numel()
|
| 133 |
+
V = torch.empty_like(u_c)
|
| 134 |
+
|
| 135 |
+
BLOCK = _TritonLinearRecurrence._BLOCK
|
| 136 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 137 |
+
|
| 138 |
+
_fwd_recurrence_kernel[grid](
|
| 139 |
+
beta_c, u_c, v_init_c, V,
|
| 140 |
+
K, num_cols,
|
| 141 |
+
BLOCK=BLOCK,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
|
| 145 |
+
ctx.save_for_backward(beta_c, V, v_init_c)
|
| 146 |
+
ctx.K = K
|
| 147 |
+
ctx.num_cols = num_cols
|
| 148 |
+
|
| 149 |
+
return V
|
| 150 |
+
|
| 151 |
+
@staticmethod
|
| 152 |
+
def backward(ctx, grad_V):
|
| 153 |
+
beta, V, v_init = ctx.saved_tensors
|
| 154 |
+
grad_V_c = grad_V.contiguous()
|
| 155 |
+
|
| 156 |
+
K = ctx.K
|
| 157 |
+
num_cols = ctx.num_cols
|
| 158 |
+
|
| 159 |
+
grad_beta = torch.empty_like(beta)
|
| 160 |
+
grad_u = torch.empty_like(beta)
|
| 161 |
+
grad_v_init = torch.empty_like(v_init)
|
| 162 |
+
|
| 163 |
+
BLOCK = _TritonLinearRecurrence._BLOCK
|
| 164 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 165 |
+
|
| 166 |
+
_bwd_recurrence_kernel[grid](
|
| 167 |
+
beta, V, v_init, grad_V_c,
|
| 168 |
+
grad_beta, grad_u, grad_v_init,
|
| 169 |
+
K, num_cols,
|
| 170 |
+
BLOCK=BLOCK,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return grad_beta, grad_u, grad_v_init
|
| 174 |
+
|
| 175 |
+
# ============================================================
|
| 176 |
+
# Fused PLIF forward/backward kernels
|
| 177 |
+
# ============================================================
|
| 178 |
+
|
| 179 |
+
    @triton.jit
    def _fused_plif_fwd_kernel(
        BETA_ptr, U_ptr, VTH_ptr, INIT_ptr,
        SPIKE_ptr, VPOST_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF forward: single-pass sequential scan with inline spike + soft reset.

        Exact computation — sequential scan IS the ground truth.
        Replaces the 3-phase approach (linear scan + spike iteration + correction).

        Per column (parallel across batch*D):
            v = v_init
            for k = 0..K-1:
                v_pre = beta[k]*v + u[k]
                spike[k] = Θ(v_pre - v_th[k])
                v = v_pre - v_th[k]*spike[k]
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        # Membrane potential carried across steps in fp32 registers.
        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        for k in range(K):
            off = k * num_cols + cols  # row-major (K, num_cols) layout
            beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
            u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)
            vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)

            v_pre = beta * v + u
            spike = tl.where(v_pre >= vth, 1.0, 0.0)  # hard threshold Θ
            v = v_pre - vth * spike  # soft reset

            tl.store(SPIKE_ptr + off, spike, mask=mask)
            tl.store(VPOST_ptr + off, v, mask=mask)
|
| 216 |
+
|
| 217 |
+
    @triton.jit
    def _fused_plif_bwd_kernel(
        BETA_ptr, VTH_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
        GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
        GRAD_BETA_ptr, GRAD_U_ptr, GRAD_VTH_ptr, GRAD_INIT_ptr,
        K, num_cols, ALPHA,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF backward: single reverse pass with Sigmoid surrogate gradient.

        V_pre[k] = V_post[k] + v_th[k]*spike[k]   (reconstructed)
        surrogate_grad(x) = alpha * sigmoid(alpha*x) * (1 - sigmoid(alpha*x))
        where x = V_pre[k] - v_th[k] = V_post[k] - v_th[k]*(1 - spike[k])

        Reverse accumulation:
            acc = 0
            for k = K-1 downto 0:
                total_gV = grad_V_post[k] + acc
                sg = surrogate_grad(V_pre[k] - v_th[k])
                grad_v_pre = grad_spike[k]*sg + total_gV
                grad_beta[k] = grad_v_pre * V_post[k-1]
                grad_u[k] = grad_v_pre
                grad_v_th[k] = -grad_spike[k]*sg - total_gV*spike[k]
                acc = grad_v_pre * beta[k]
            grad_v_init = acc
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        # Adjoint flowing backwards through the membrane recurrence.
        acc = tl.zeros([BLOCK], dtype=tl.float32)

        for k_rev in range(K):
            k = K - 1 - k_rev  # reverse time order
            off = k * num_cols + cols

            beta = tl.load(BETA_ptr + off, mask=mask, other=0.0).to(tl.float32)
            vth = tl.load(VTH_ptr + off, mask=mask, other=0.0).to(tl.float32)
            v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
            spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)

            g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)

            # V_post[k-1]
            if k > 0:
                v_prev = tl.load(
                    VPOST_ptr + (k - 1) * num_cols + cols,
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

            # Sigmoid surrogate gradient
            x = v_post - vth * (1.0 - spike)  # = V_pre - v_th
            neg_ax = -ALPHA * x
            neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)  # prevent exp overflow
            sig = 1.0 / (1.0 + tl.exp(neg_ax))
            sg = ALPHA * sig * (1.0 - sig)

            total_gV = g_V + acc
            grad_v_pre = g_s * sg + total_gV

            tl.store(GRAD_BETA_ptr + off, grad_v_pre * v_prev, mask=mask)
            tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)
            tl.store(GRAD_VTH_ptr + off, -g_s * sg - total_gV * spike, mask=mask)

            acc = grad_v_pre * beta

        tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
|
| 287 |
+
|
| 288 |
+
# ============================================================
|
| 289 |
+
# Fused PLIF kernels with row-parameter beta/v_th
|
| 290 |
+
# (constant across K steps — e.g., ParametricLIFNode scalars)
|
| 291 |
+
# ============================================================
|
| 292 |
+
|
| 293 |
+
    @triton.jit
    def _fused_plif_fwd_rowparam_kernel(
        BETA_ROW_ptr, U_ptr, VTH_ROW_ptr, INIT_ptr,
        SPIKE_ptr, VPOST_ptr,
        K, num_cols,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF forward with row-parameter beta and v_th.

        beta and v_th are (*shape) — constant across K steps, loaded once into registers.
        Reduces global memory reads from 3 per step (beta, u, v_th) to 1 (u only).
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        v = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        # Time-invariant parameters: one load serves all K steps.
        beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        for k in range(K):
            off = k * num_cols + cols
            u = tl.load(U_ptr + off, mask=mask, other=0.0).to(tl.float32)

            v_pre = beta * v + u
            spike = tl.where(v_pre >= vth, 1.0, 0.0)  # hard threshold
            v = v_pre - vth * spike  # soft reset

            tl.store(SPIKE_ptr + off, spike, mask=mask)
            tl.store(VPOST_ptr + off, v, mask=mask)
|
| 323 |
+
|
| 324 |
+
    @triton.jit
    def _fused_plif_bwd_rowparam_kernel(
        BETA_ROW_ptr, VTH_ROW_ptr, INIT_ptr, VPOST_ptr, SPIKE_ptr,
        GRAD_SPIKE_ptr, GRAD_VPOST_ptr,
        GRAD_BETA_ROW_ptr, GRAD_U_ptr, GRAD_VTH_ROW_ptr, GRAD_INIT_ptr,
        K, num_cols, ALPHA,
        BLOCK: tl.constexpr,
    ):
        """Fused PLIF backward with row-parameter beta/v_th.

        Gradients for beta and v_th are accumulated over K steps (reduction in registers).
        Returns grad_beta_row (*shape) and grad_v_th_row (*shape) instead of per-step gradients.
        """
        pid = tl.program_id(0)
        cols = pid * BLOCK + tl.arange(0, BLOCK)
        mask = cols < num_cols

        # Row parameters are constant over time: load once into registers.
        beta = tl.load(BETA_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        vth = tl.load(VTH_ROW_ptr + cols, mask=mask, other=0.0).to(tl.float32)

        # acc carries the adjoint through the recurrence; the other two
        # accumulate per-row parameter gradients across all K steps.
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        acc_grad_beta = tl.zeros([BLOCK], dtype=tl.float32)
        acc_grad_vth = tl.zeros([BLOCK], dtype=tl.float32)

        for k_rev in range(K):
            k = K - 1 - k_rev  # reverse time order
            off = k * num_cols + cols

            v_post = tl.load(VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)
            spike = tl.load(SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)

            g_s = tl.load(GRAD_SPIKE_ptr + off, mask=mask, other=0.0).to(tl.float32)
            g_V = tl.load(GRAD_VPOST_ptr + off, mask=mask, other=0.0).to(tl.float32)

            # V_post[k-1]; the k == 0 step reads the initial state.
            if k > 0:
                v_prev = tl.load(
                    VPOST_ptr + (k - 1) * num_cols + cols,
                    mask=mask, other=0.0,
                ).to(tl.float32)
            else:
                v_prev = tl.load(INIT_ptr + cols, mask=mask, other=0.0).to(tl.float32)

            # Sigmoid surrogate gradient
            x = v_post - vth * (1.0 - spike)  # = V_pre - v_th
            neg_ax = -ALPHA * x
            neg_ax = tl.where(neg_ax > 88.0, 88.0, neg_ax)  # prevent exp overflow
            sig = 1.0 / (1.0 + tl.exp(neg_ax))
            sg = ALPHA * sig * (1.0 - sig)

            total_gV = g_V + acc
            grad_v_pre = g_s * sg + total_gV

            tl.store(GRAD_U_ptr + off, grad_v_pre, mask=mask)

            # Accumulate gradients for row parameters (reduction over K in registers)
            acc_grad_beta += grad_v_pre * v_prev
            acc_grad_vth += -g_s * sg - total_gV * spike

            acc = grad_v_pre * beta

        tl.store(GRAD_INIT_ptr + cols, acc, mask=mask)
        tl.store(GRAD_BETA_ROW_ptr + cols, acc_grad_beta, mask=mask)
        tl.store(GRAD_VTH_ROW_ptr + cols, acc_grad_vth, mask=mask)
|
| 387 |
+
|
| 388 |
+
    class _TritonPLIFRowParamForward(torch.autograd.Function):
        """Fused Triton PLIF with row-parameter beta/v_th.

        For neurons with constant beta/v_th across K steps (ParametricLIFNode).
        Eliminates expand+contiguous for beta/v_th tensors, reduces memory I/O by ~40%.
        """

        _BLOCK = 128

        @staticmethod
        def forward(ctx, beta_row, u, v_th_row, v_init, alpha):
            # Kernels assume contiguous (K, num_cols) / (num_cols,) layouts.
            beta_row_c = beta_row.contiguous()
            u_c = u.contiguous()
            v_th_row_c = v_th_row.contiguous()
            v_init_c = v_init.contiguous()

            K = u_c.shape[0]
            num_cols = u_c[0].numel()

            spike = torch.empty_like(u_c)
            V_post = torch.empty_like(u_c)

            BLOCK = _TritonPLIFRowParamForward._BLOCK
            grid = ((num_cols + BLOCK - 1) // BLOCK,)

            _fused_plif_fwd_rowparam_kernel[grid](
                beta_row_c, u_c, v_th_row_c, v_init_c,
                spike, V_post,
                K, num_cols,
                BLOCK=BLOCK,
            )

            # Save tensors only when some tensor input needs a gradient
            # (alpha at index 4 is a plain float and never does).
            if any(ctx.needs_input_grad[:4]):
                ctx.save_for_backward(beta_row_c, v_th_row_c, v_init_c, V_post, spike)
                ctx.K = K
                ctx.num_cols = num_cols
                ctx.alpha = alpha

            return spike, V_post

        @staticmethod
        def backward(ctx, grad_spike, grad_V_post):
            beta_row, v_th_row, v_init, V_post, spike = ctx.saved_tensors
            K = ctx.K
            num_cols = ctx.num_cols
            alpha = ctx.alpha

            # Autograd passes None for an output unused downstream.
            if grad_spike is None:
                grad_spike = torch.zeros_like(spike)
            if grad_V_post is None:
                grad_V_post = torch.zeros_like(V_post)

            grad_spike_c = grad_spike.contiguous()
            grad_V_post_c = grad_V_post.contiguous()

            grad_beta_row = torch.empty_like(beta_row)
            grad_u = torch.empty_like(V_post)
            grad_v_th_row = torch.empty_like(v_th_row)
            grad_v_init = torch.empty_like(v_init)

            BLOCK = _TritonPLIFRowParamForward._BLOCK
            grid = ((num_cols + BLOCK - 1) // BLOCK,)

            _fused_plif_bwd_rowparam_kernel[grid](
                beta_row, v_th_row, v_init, V_post, spike,
                grad_spike_c, grad_V_post_c,
                grad_beta_row, grad_u, grad_v_th_row, grad_v_init,
                K, num_cols, float(alpha),
                BLOCK=BLOCK,
            )

            # alpha (float) receives no gradient.
            return grad_beta_row, grad_u, grad_v_th_row, grad_v_init, None
|
| 460 |
+
|
| 461 |
+
class _TritonPLIFForward(torch.autograd.Function):
|
| 462 |
+
"""Fused Triton PLIF forward + backward.
|
| 463 |
+
|
| 464 |
+
Single-pass sequential scan replaces the 3-phase approach:
|
| 465 |
+
Phase 1 (linear scan) + Phase 2 (spike iteration) + Phase 3 (correction)
|
| 466 |
+
→ 1 fused kernel with inline spike detection + soft reset
|
| 467 |
+
|
| 468 |
+
Advantages:
|
| 469 |
+
- 1 kernel launch (vs 3-4 launches + ~10 element-wise ops)
|
| 470 |
+
- Exact computation (no iteration convergence issues)
|
| 471 |
+
- Less memory (no intermediate V_L, delta_S, delta_S_prev)
|
| 472 |
+
- Higher precision (fp32 accumulation, no bf16 intermediate store/load)
|
| 473 |
+
"""
|
| 474 |
+
|
| 475 |
+
_BLOCK = 128
|
| 476 |
+
|
| 477 |
+
@staticmethod
|
| 478 |
+
def forward(ctx, beta, u, v_th, v_init, alpha):
|
| 479 |
+
beta_c = beta.contiguous()
|
| 480 |
+
u_c = u.contiguous()
|
| 481 |
+
v_th_c = v_th.contiguous()
|
| 482 |
+
v_init_c = v_init.contiguous()
|
| 483 |
+
|
| 484 |
+
K = beta_c.shape[0]
|
| 485 |
+
num_cols = beta_c[0].numel()
|
| 486 |
+
|
| 487 |
+
spike = torch.empty_like(u_c)
|
| 488 |
+
V_post = torch.empty_like(u_c)
|
| 489 |
+
|
| 490 |
+
BLOCK = _TritonPLIFForward._BLOCK
|
| 491 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 492 |
+
|
| 493 |
+
_fused_plif_fwd_kernel[grid](
|
| 494 |
+
beta_c, u_c, v_th_c, v_init_c,
|
| 495 |
+
spike, V_post,
|
| 496 |
+
K, num_cols,
|
| 497 |
+
BLOCK=BLOCK,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if any(ctx.needs_input_grad[:4]):
|
| 501 |
+
ctx.save_for_backward(beta_c, v_th_c, v_init_c, V_post, spike)
|
| 502 |
+
ctx.K = K
|
| 503 |
+
ctx.num_cols = num_cols
|
| 504 |
+
ctx.alpha = alpha
|
| 505 |
+
|
| 506 |
+
return spike, V_post
|
| 507 |
+
|
| 508 |
+
@staticmethod
|
| 509 |
+
def backward(ctx, grad_spike, grad_V_post):
|
| 510 |
+
beta, v_th, v_init, V_post, spike = ctx.saved_tensors
|
| 511 |
+
K = ctx.K
|
| 512 |
+
num_cols = ctx.num_cols
|
| 513 |
+
alpha = ctx.alpha
|
| 514 |
+
|
| 515 |
+
if grad_spike is None:
|
| 516 |
+
grad_spike = torch.zeros_like(spike)
|
| 517 |
+
if grad_V_post is None:
|
| 518 |
+
grad_V_post = torch.zeros_like(V_post)
|
| 519 |
+
|
| 520 |
+
grad_spike_c = grad_spike.contiguous()
|
| 521 |
+
grad_V_post_c = grad_V_post.contiguous()
|
| 522 |
+
|
| 523 |
+
grad_beta = torch.empty_like(beta)
|
| 524 |
+
grad_u = torch.empty_like(beta)
|
| 525 |
+
grad_v_th = torch.empty_like(v_th)
|
| 526 |
+
grad_v_init = torch.empty_like(v_init)
|
| 527 |
+
|
| 528 |
+
BLOCK = _TritonPLIFForward._BLOCK
|
| 529 |
+
grid = ((num_cols + BLOCK - 1) // BLOCK,)
|
| 530 |
+
|
| 531 |
+
_fused_plif_bwd_kernel[grid](
|
| 532 |
+
beta, v_th, v_init, V_post, spike,
|
| 533 |
+
grad_spike_c, grad_V_post_c,
|
| 534 |
+
grad_beta, grad_u, grad_v_th, grad_v_init,
|
| 535 |
+
K, num_cols, float(alpha),
|
| 536 |
+
BLOCK=BLOCK,
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
return grad_beta, grad_u, grad_v_th, grad_v_init, None
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# ============================================================
|
| 543 |
+
# Hillis-Steele parallel prefix scan (CPU fallback)
|
| 544 |
+
# ============================================================
|
| 545 |
+
|
| 546 |
+
def hillis_steele_scan(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hillis-Steele parallel prefix scan over a sequence of affine maps.

    Given f_k(x) = a[k] * x + b[k] for k = 0..K-1, compute the prefix
    compositions F_k = f_k ∘ f_{k-1} ∘ ... ∘ f_0 represented as pairs
    (A[k], B[k]) so that V[k] = F_k(v_init) = A[k] * v_init + B[k].

    Composition rule: (a2, b2) ∘ (a1, b1) = (a2 * a1, a2 * b1 + b2).

    The tensors are rebuilt with ``torch.cat`` at every doubling step (no
    in-place writes), so the scan is fully autograd-compatible.

    Args:
        a: (K, *shape) multiplicative coefficients (e.g. β).
        b: (K, *shape) additive terms (e.g. α·I).

    Returns:
        A: (K, *shape) prefix products A[k] = Π_{j≤k} a[j].
        B: (K, *shape) prefix offsets such that V[k] = A[k] * v_init + B[k].

    Parallel depth O(log K); total work O(K log K).
    """
    length = a.shape[0]
    prod, offset = a, b

    stride = 1
    while stride < length:
        # Snapshot all slices before reassigning, so both updates use the
        # pre-step values.
        shifted_prod = prod[:-stride]
        shifted_off = offset[:-stride]
        head_prod, tail_prod = prod[:stride], prod[stride:]
        head_off, tail_off = offset[:stride], offset[stride:]

        prod = torch.cat([head_prod, tail_prod * shifted_prod], dim=0)
        offset = torch.cat([head_off, tail_prod * shifted_off + tail_off], dim=0)

        stride <<= 1

    return prod, offset
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
# ============================================================
|
| 587 |
+
# Public API: linear_recurrence
|
| 588 |
+
# ============================================================
|
| 589 |
+
|
| 590 |
+
def linear_recurrence(beta: torch.Tensor, u: torch.Tensor, v_init: torch.Tensor) -> torch.Tensor:
    """Solve the linear recurrence V[k] = beta[k] * V[k-1] + u[k], V[-1] = v_init.

    CUDA backend: fused Triton kernel (single launch, O(K) work).
    CPU backend:  Hillis-Steele parallel scan (O(K log K) work).

    Args:
        beta: (K, *shape) decay coefficients, valued in (0, 1).
        u: (K, *shape) input terms.
        v_init: (*shape) initial state.

    Returns:
        (K, *shape) tensor holding the state at every one of the K steps.
    """
    if _HAS_TRITON and beta.is_cuda:
        return _TritonLinearRecurrence.apply(beta, u, v_init)
    # CPU fallback: compose the affine maps, then apply them all to v_init.
    prefix_mul, prefix_add = hillis_steele_scan(beta, u)
    return prefix_mul * v_init.unsqueeze(0) + prefix_add
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
# ============================================================
|
| 614 |
+
# PLIF parallel forward (with spike iteration)
|
| 615 |
+
# ============================================================
|
| 616 |
+
|
| 617 |
+
def plif_parallel_forward(
    beta: torch.Tensor,
    u: torch.Tensor,
    v_th: torch.Tensor,
    v_init: torch.Tensor,
    max_iter: int = 3,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Parallel forward pass for PLIF neurons (soft reset, surrogate-gradient compatible).

    Solves:
        V_pre[k]  = beta[k] * V_post[k-1] + u[k]
        s[k]      = Θ(V_pre[k] - v_th[k])
        V_post[k] = V_pre[k] - v_th[k] * s[k]

    Method:
        Phase 1: linear-trajectory parallel scan (with gradients)
        Phase 2: spike fixed-point iteration (detached; pins down the discrete
                 spike pattern)
        Phase 3: recompute V_post from the converged spike pattern (with
                 gradients); surrogate_function(V_pre - v_th) produces the
                 differentiable spike output

    Args:
        beta: (K, *shape) — decay coefficients
        u: (K, *shape) — input term α·I
        v_th: (K, *shape) — dynamic thresholds
        v_init: (*shape) — initial membrane potential
        max_iter: upper bound on spike fixed-point iterations
        surrogate_function: SpikingJelly surrogate gradient function
            (e.g. surrogate.Sigmoid(alpha=4.0)); None degrades to a hard
            threshold with no gradient
    Returns:
        spike: (K, *shape) — spike pattern (carries surrogate gradient)
        V_post: (K, *shape) — post-firing membrane potential
        V_pre: (K, *shape) — pre-firing membrane potential (None on the fused path)
    """
    # Fused Triton path: single-pass sequential scan (exact, no iteration).
    # Replaces the 3-phase approach with 1 kernel launch — ~3x faster forward,
    # ~5x faster backward. Only taken for the Sigmoid surrogate, whose `alpha`
    # the kernel consumes directly.
    if (_HAS_TRITON and beta.is_cuda and surrogate_function is not None
            and hasattr(surrogate_function, 'alpha')
            and type(surrogate_function).__name__ == 'Sigmoid'):
        alpha = float(surrogate_function.alpha)
        spike, V_post = _TritonPLIFForward.apply(beta, u, v_th, v_init, alpha)
        return spike, V_post, None

    # Fallback: 3-phase approach (CPU, non-Sigmoid surrogates, or no surrogate)
    # Phase 1: linear trajectory V_L (as if the neuron never fires)
    V_L = linear_recurrence(beta, u, v_init)  # (K, *shape)

    # Phase 2: spike fixed-point iteration (fully detached; builds no graph).
    # Goal: decide which neurons fire at which steps (a discrete decision).
    with torch.no_grad():
        V_L_det = V_L.detach()
        beta_det = beta.detach()
        v_th_det = v_th.detach()
        v_init_det = v_init.detach() if isinstance(v_init, torch.Tensor) else v_init

        # Initial guess: threshold the no-fire trajectory directly.
        spike_pattern = (V_L_det >= v_th_det).float()

        for _ in range(max_iter - 1):
            # Accumulated reset correction: ΔS[k] = beta[k] * ΔS[k-1] + v_th[k] * s[k]
            delta_S = linear_recurrence(
                beta_det, v_th_det * spike_pattern,
                torch.zeros_like(v_init_det) if isinstance(v_init_det, torch.Tensor)
                else torch.zeros_like(V_L_det[0]),
            )

            # ΔS_prev = ΔS[k-1] (shift by one step)
            delta_S_prev = torch.zeros_like(delta_S)
            delta_S_prev[1:] = delta_S[:-1]

            # V_pre = V_L - beta * ΔS_prev
            V_pre_det = V_L_det - beta_det * delta_S_prev

            # Refresh the spike pattern against the corrected trajectory
            spike_new = (V_pre_det >= v_th_det).float()

            # Convergence check: stop once the pattern is a fixed point
            if torch.equal(spike_new, spike_pattern):
                break
            spike_pattern = spike_new

    # Phase 3: recompute V_post using the converged spike pattern (full gradients).
    # spike_pattern is detached and participates as a constant; gradients flow
    # through u, v_th, beta and v_init.
    u_eff = u - v_th * spike_pattern
    V_post = linear_recurrence(beta, u_eff, v_init)  # (K, *shape)

    # Rebuild V_pre (with gradients, feeds the surrogate gradient below)
    V_post_prev = torch.zeros_like(V_post)
    if isinstance(v_init, torch.Tensor):
        V_post_prev[0] = v_init
    V_post_prev[1:] = V_post[:-1]
    V_pre = beta * V_post_prev + u

    # Produce the differentiable spike output
    if surrogate_function is not None:
        # forward: Heaviside(V_pre - v_th), backward: surrogate gradient
        spike = surrogate_function(V_pre - v_th)
    else:
        # Degraded mode: hard threshold, no gradient
        spike = (V_pre >= v_th).float()

    return spike, V_post, V_pre
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
def plif_rowparam_forward(
    beta_row: torch.Tensor,
    u: torch.Tensor,
    v_th_row: torch.Tensor,
    v_init: torch.Tensor,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Row-parameter PLIF forward: beta and v_th stay constant over the K steps.

    ~40% faster than plif_parallel_forward (skips expand+contiguous and drops
    2/3 of the memory reads). Used for ParametricLIFNode (fixed beta/v_th) or
    when merging several fixed-parameter neurons.

    Args:
        beta_row: (*shape) — per-column decay rate (identical across all K steps)
        v_th_row: (*shape) — per-column threshold (identical across all K steps)
        u: (K, *shape) — per-step input
        v_init: (*shape) — initial membrane potential
        surrogate_function: surrogate gradient function

    Returns:
        spike: (K, *shape) — spike pattern
        V_post: (K, *shape) — post-firing membrane potential
    """
    fused_eligible = (
        _HAS_TRITON
        and u.is_cuda
        and surrogate_function is not None
        and hasattr(surrogate_function, 'alpha')
        and type(surrogate_function).__name__ == 'Sigmoid'
    )
    if fused_eligible:
        return _TritonPLIFRowParamForward.apply(
            beta_row, u, v_th_row, v_init, float(surrogate_function.alpha),
        )

    # Fallback: broadcast the row parameters to full (K, *shape) tensors and
    # reuse the generic parallel path.
    steps = u.shape[0]
    per_step_shape = u.shape[1:]
    beta_full = beta_row.unsqueeze(0).expand(steps, *per_step_shape).contiguous()
    v_th_full = v_th_row.unsqueeze(0).expand(steps, *per_step_shape).contiguous()
    spike, V_post, _ = plif_parallel_forward(
        beta_full, u, v_th_full, v_init, surrogate_function=surrogate_function,
    )
    return spike, V_post
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
def plif_fixed_param_forward(
    beta,
    u: torch.Tensor,
    v_th,
    v_init: torch.Tensor,
    max_iter: int = 3,
    surrogate_function=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Parallel forward for fixed-parameter PLIF neurons (output/FFN neurons).

    ParametricLIFNode equation: V[k] = beta * V[k-1] + (1-beta) * x[k],
    where beta = 1/(1+exp(w)) and may be a scalar tensor (keeping the
    gradient flowing into w).

    Scalar / 0-dim beta and v_th take the row-param kernel (no expansion to
    (K, *shape) needed).

    Args:
        beta: decay rate — scalar float, 0-dim tensor, or (K, *shape) tensor
        u: (K, *shape) — input (already multiplied by (1-beta))
        v_th: threshold — scalar float, 0-dim tensor, or (K, *shape) tensor
        v_init: (*shape) — initial membrane potential
        max_iter: number of spike fixed-point iterations
        surrogate_function: surrogate gradient function

    Returns:
        spike: (K, *shape) — spike pattern
        V_post: (K, *shape) — post-firing membrane potential
    """
    num_steps = u.shape[0]
    row_shape = u.shape[1:]

    beta_is_tensor = isinstance(beta, torch.Tensor)
    vth_is_tensor = isinstance(v_th, torch.Tensor)
    # "small" = representable as a single value: plain float or 0-dim tensor.
    beta_small = (not beta_is_tensor) or beta.dim() == 0
    vth_small = (not vth_is_tensor) or v_th.dim() == 0

    # Row-param fast path: both parameters are scalar-like → lift them to
    # (*shape) row vectors and dispatch to the row-param kernel.
    if beta_small and vth_small:
        if beta_is_tensor:
            beta_row = beta.expand(*row_shape).contiguous()
        else:
            beta_row = torch.full(row_shape, beta, device=u.device, dtype=u.dtype)
        if vth_is_tensor:
            v_th_row = v_th.expand(*row_shape).contiguous()
        else:
            v_th_row = torch.full(row_shape, v_th, device=u.device, dtype=u.dtype)
        return plif_rowparam_forward(beta_row, u, v_th_row, v_init, surrogate_function)

    # Full-tensor path: bring both parameters to (K, *shape) where needed.
    if beta_is_tensor:
        if beta.dim() == 0:
            beta = beta.expand(num_steps, *row_shape).contiguous()
    else:
        beta = torch.full_like(u, beta)

    if vth_is_tensor:
        if v_th.dim() == 0:
            v_th = v_th.expand(num_steps, *row_shape).contiguous()
    else:
        v_th = torch.full_like(u, v_th)

    spike, V_post, _ = plif_parallel_forward(
        beta, u, v_th, v_init, max_iter, surrogate_function,
    )
    return spike, V_post
|
atomic_ops/plif_node.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PLIFNode: D 维固定参数 PLIF 神经元(设计文档 5.5 "普通 SNN 神经元")
|
| 3 |
+
|
| 4 |
+
与 SelectivePLIFNode 的区别:
|
| 5 |
+
SelectivePLIF: β(t), α(t), V_th(t) 由输入每步动态计算(选择性记忆)
|
| 6 |
+
PLIFNode: β, V_th 为 D 维可学习参数,训练后固定(信号转换)
|
| 7 |
+
|
| 8 |
+
每个维度有独立的可学习参数:
|
| 9 |
+
β_d = sigmoid(w_d): 时间常数(衰减率)
|
| 10 |
+
V_th_d: 发放阈值
|
| 11 |
+
|
| 12 |
+
动力学(与 ParametricLIF 一致):
|
| 13 |
+
V[t] = β · V[t-1] + (1-β) · x[t]
|
| 14 |
+
s[t] = Θ(V[t] - V_th) (surrogate gradient)
|
| 15 |
+
V[t] -= V_th · s[t] (soft reset)
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from spikingjelly.activation_based import base, surrogate
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PLIFNode(base.MemoryModule):
    """
    D-dimensional fixed-parameter PLIF neuron.

    Args:
        dim: number of neurons (independent parameters per dimension)
        init_tau: initial time constant τ (β = 1 - 1/τ)
        v_threshold: initial firing threshold
        surrogate_function: surrogate gradient function
    """

    def __init__(
        self,
        dim: int,
        init_tau: float = 2.0,
        v_threshold: float = 0.5,
        surrogate_function=surrogate.Sigmoid(alpha=4.0),
    ):
        super().__init__()
        # Per-dimension learnable parameters, randomly initialised so each
        # dimension gets its own dynamics:
        #   w controls β = sigmoid(w); a ±0.5 spread around init_w yields
        #   β ∈ ~[sigmoid(w-0.5), sigmoid(w+0.5)].
        #   With tau = 2.0, w = 0 and β ∈ ~[0.38, 0.62].
        w_center = -math.log(init_tau - 1.0)
        self.w = nn.Parameter(torch.empty(dim).normal_(w_center, 0.5))
        # Firing threshold: uniform over [0.5x, 1.5x] for per-dimension diversity.
        self.v_th = nn.Parameter(
            torch.empty(dim).uniform_(v_threshold * 0.5, v_threshold * 1.5)
        )
        self.surrogate_function = surrogate_function
        # Membrane-potential state (functional.reset_net resets it to 0.).
        self.register_memory('v', 0.)

    @property
    def beta(self):
        """Per-dimension decay rate β = sigmoid(w), values in (0, 1)."""
        return torch.sigmoid(self.w)

    def forward(self, x):
        """
        Single-step forward pass.

        V[t] = β · V[t-1] + (1-β) · x[t], spike = Θ(V - V_th), soft reset.

        Args:
            x: input current, shape (batch, dim)

        Returns:
            spike: binary spikes, shape (batch, dim), values in {0, 1}
        """
        # Lazily materialise the state tensor on the first step.
        if isinstance(self.v, float):
            self.v = torch.zeros_like(x)
        decay = self.beta
        membrane = decay * self.v + (1.0 - decay) * x
        spike = self.surrogate_function(membrane - self.v_th)
        # Soft reset: subtract the threshold where a spike occurred.
        self.v = membrane - spike * self.v_th
        return spike
|
atomic_ops/rms_norm.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RMSNorm: 残差流分支归一化(Pre-LN 模式)。
|
| 3 |
+
|
| 4 |
+
位置: h → RMSNorm → PLIFNode → SNN子层 → out_proj → 残差
|
| 5 |
+
作用: 控制送入 PLIFNode 的输入 scale,防止残差流漂移/爆炸。
|
| 6 |
+
仅归一化分支输入,残差流本身不被归一化。
|
| 7 |
+
|
| 8 |
+
对标 Qwen3/LLaMA 的 Pre-LN RMSNorm。
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    x_norm = x / RMS(x) * weight
    RMS(x) = sqrt(mean(x^2) + eps)

    Args:
        dim: normalised feature dimension
        eps: numerical-stability constant
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in fp32 for numerical stability, then cast back.
        orig_dtype = x.dtype
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (self.weight * (x32 * inv_rms)).to(orig_dtype)
|
atomic_ops/selective_plif.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SelectivePLIFNode: 动态参数的 PLIF 神经元
|
| 3 |
+
|
| 4 |
+
与标准 ParametricLIFNode 的区别:
|
| 5 |
+
- β(t), α(t), V_th(t) 作为外部参数每步传入,不是内部 nn.Parameter
|
| 6 |
+
- 本神经元无可训练参数,所有学习发生在 SNNBlock 的调制网络中
|
| 7 |
+
- 仅支持 step_mode='s'(单步模式)
|
| 8 |
+
- 仅支持 soft reset(v_reset=None)
|
| 9 |
+
|
| 10 |
+
状态方程:
|
| 11 |
+
V[t] = β(t) · V[t-1] + α(t) · I[t]
|
| 12 |
+
s[t] = Θ(V[t] - V_th(t)) (surrogate gradient)
|
| 13 |
+
V[t] -= V_th(t) · s[t] (soft reset)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from spikingjelly.activation_based import neuron, surrogate
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SelectivePLIFNode(neuron.BaseNode):
    """
    Core neuron of the hidden state space.

    Receives externally computed dynamic β(t), α(t), V_th(t) and runs:
        charge → fire → soft reset

    Args:
        surrogate_function: surrogate gradient function, default Sigmoid(alpha=4.0)
        detach_reset: whether to detach the spike during reset, default False
    """

    def __init__(
        self,
        surrogate_function=surrogate.Sigmoid(alpha=4.0),
        detach_reset: bool = False,
    ):
        # v_threshold=1.0 is a placeholder; the externally supplied v_th is
        # what actually gates firing. v_reset=None selects soft-reset mode
        # and registers the memory slot register_memory('v', 0.).
        super().__init__(
            v_threshold=1.0,
            v_reset=None,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            step_mode='s',
            backend='torch',
            store_v_seq=False,
        )

    def single_step_forward(
        self,
        x: torch.Tensor,
        beta: torch.Tensor,
        alpha: torch.Tensor,
        v_th: torch.Tensor,
    ) -> torch.Tensor:
        """
        Single-step forward pass.

        Args:
            x: input current I[t], shape (batch, D*N)
            beta: decay rate β(t), shape (batch, D*N), values in (0, 1)
            alpha: write gain α(t), shape (batch, D*N), values in R+
            v_th: dynamic threshold V_th(t), shape (batch, D*N), values in R+

        Returns:
            spike: binary spikes s[t], shape (batch, D*N), values in {0, 1}
        """
        # Phase 0: on the first step, promote v from float to a tensor
        # matching the input shape.
        self.v_float_to_tensor(x)

        # Phase 1: charge — membrane update V[t] = β(t)·V[t-1] + α(t)·I[t]
        self.v = beta * self.v + alpha * x

        # Phase 2: fire — against the dynamic v_th (NOT self.v_threshold);
        # forward is Heaviside(V[t] - V_th(t)), backward uses the surrogate.
        spike = self.surrogate_function(self.v - v_th)

        # Phase 3: soft reset — V[t] -= V_th(t)·s[t]
        reset_spike = spike.detach() if self.detach_reset else spike
        self.v = self.v - reset_spike * v_th

        return spike

    def extra_repr(self) -> str:
        parts = [
            f'v_reset={self.v_reset}',
            f'detach_reset={self.detach_reset}',
            f'step_mode={self.step_mode}',
            f'surrogate={self.surrogate_function}',
        ]
        return ', '.join(parts)
|
atomic_ops/snn_block.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SNNBlock: 完整的 SNN 隐状态空间 Block(并行化版本)
|
| 3 |
+
|
| 4 |
+
结构(每个 SNN 时间步):
|
| 5 |
+
spike_in {0,1}^D
|
| 6 |
+
├─→ W_in → I[t] ∈ R^{D*N}
|
| 7 |
+
├─→ W_β^(x) + b_β → σ → β(t)
|
| 8 |
+
├─→ W_α^(x) + b_α → softplus → α(t)
|
| 9 |
+
├─→ W_th^(x) + b_th → |·|+V_min → V_th(t)
|
| 10 |
+
├─→ W_gate → sigmoid → gate ∈ (0,1)^D
|
| 11 |
+
└─→ W_skip → I_skip ∈ R^D
|
| 12 |
+
|
| 13 |
+
SelectivePLIF(I, β, α, V_th) → s[t] ∈ {0,1}^{D*N}
|
| 14 |
+
|
| 15 |
+
W_out · V_post[t] ⊙ gate + I_skip → 连续输出 ∈ R^D
|
| 16 |
+
|
| 17 |
+
数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from spikingjelly.activation_based import base, layer, surrogate
|
| 25 |
+
|
| 26 |
+
from .selective_plif import SelectivePLIFNode
|
| 27 |
+
from .parallel_scan import plif_parallel_forward
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ====== Fused modulation activations (torch.compile) ======
|
| 31 |
+
# Fuse sigmoid + softplus + abs + alpha*I into single kernel.
|
| 32 |
+
# 7-8 separate element-wise kernels → 1 fused kernel, ~4x speedup on DN-sized tensors.
|
| 33 |
+
# First call triggers JIT compilation (~seconds); cached for subsequent calls.
|
| 34 |
+
|
| 35 |
+
@torch.compile(backend='inductor', fullgraph=True)
def _fused_modulation(raw_beta, b_beta, raw_alpha, b_alpha, raw_th, b_th, v_th_min, I_all):
    """Fused modulation activations for the SNN block.

    Computes, in one compiled kernel:
        beta = sigmoid(raw_beta + b_beta)
        u    = softplus(raw_alpha + b_alpha) * I_all
        v_th = v_th_min + |raw_th + b_th|

    Returns (beta, u, v_th).
    """
    decay = torch.sigmoid(raw_beta + b_beta)
    write_gain = F.softplus(raw_alpha + b_alpha)
    threshold = v_th_min + torch.abs(raw_th + b_th)
    drive = write_gain * I_all
    return decay, drive, threshold
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SNNBlock(base.MemoryModule):
    """
    A single SNN block (parallelised).

    Args:
        D: visible dimension (the dimension blocks communicate over)
        N: state expansion factor (hidden neurons per channel)
        v_th_min: lower bound of the dynamic threshold
        surrogate_function: surrogate gradient function
    """

    def __init__(
        self,
        D: int,
        N: int = 8,
        v_th_min: float = 0.1,
        surrogate_function=surrogate.Sigmoid(alpha=4.0),
    ):
        super().__init__()
        self.D = D
        self.N = N
        self.v_th_min = v_th_min
        DN = D * N

        # ====== Six parallel input projections (SNN synapses: spike input) ======
        self.W_in = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_beta_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_alpha_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_th_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_gate = layer.Linear(D, D, bias=False, step_mode='s')
        self.W_skip = layer.Linear(D, D, bias=False, step_mode='s')

        # ====== β/α/V_th depend on spike_in only (no W^(V)·V term) ======

        # ====== Modulation biases (structured initialisation below) ======
        self.b_beta = nn.Parameter(torch.empty(DN))
        self.b_alpha = nn.Parameter(torch.empty(DN))
        self.b_th = nn.Parameter(torch.empty(DN))

        # ====== Output projection: D*N → D (SNN synapse) ======
        self.W_out = layer.Linear(DN, D, bias=False, step_mode='s')

        # ====== Hidden state-space neurons (D*N of them, dynamic parameters) ======
        self.hidden_neuron = SelectivePLIFNode(
            surrogate_function=surrogate_function,
            detach_reset=False,
        )

        # ====== Parameter initialisation ======
        self._initialize_parameters()

    def _initialize_parameters(self):
        """Function-guided initialisation of biases and projection weights."""
        D, N = self.D, self.N
        K_ref = 16

        # Target β distribution: multiple time scales spanning [0.80, 0.99]
        beta_values = torch.linspace(0.80, 0.99, N)

        # ====== 1. β bias: logit-spaced + per-dimension random perturbation ======
        b_beta_per_n = torch.log(beta_values / (1.0 - beta_values))
        # Use the per-n values as means, add N(0, 0.1) noise to break the
        # symmetry across the D channels
        self.b_beta.data.copy_(b_beta_per_n.repeat(D))
        self.b_beta.data.add_(torch.empty_like(self.b_beta).normal_(0, 0.1))

        # ====== 2. α bias: softplus(0.5413) ≈ 1.0 + random perturbation ======
        # Mean 0.5413 with N(0, 0.1) noise → α ∈ ~[0.7, 1.3]
        self.b_alpha.data.normal_(0.5413, 0.1)

        # ====== 3. W^(x) weights ======
        for lin in [self.W_in, self.W_gate, self.W_skip, self.W_out]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
        # Modulation projections get a 0.1 shrink so the biases dominate early.
        for lin in [self.W_beta_x, self.W_alpha_x, self.W_th_x]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
            lin.weight.data.mul_(0.1)

        # ====== 4. W_in time-scale scaling ======
        scale_per_n = torch.sqrt(1.0 - beta_values ** 2)  # (N,)
        scale_DN = scale_per_n.repeat(D)  # (D*N,)
        with torch.no_grad():
            self.W_in.weight.mul_(scale_DN.unsqueeze(1))

        # ====== 5. b_th: σ_V calibration ======
        # σ_V = sqrt(p/3) * sqrt(1 - β^{2K})
        # where p is the input firing rate. An earlier version assumed p=0.5
        # (σ_I=0.408), but the observed input_neuron firing rate is roughly
        # 0.07~0.45, lower in deeper layers. Use p=0.15 as a conservative
        # estimate to avoid dead neurons from an overly high v_th.
        p_assumed = 0.15
        sigma_I_base = math.sqrt(p_assumed / 3.0)
        sigma_V_per_n = sigma_I_base * torch.sqrt(
            1.0 - beta_values ** (2 * K_ref)
        )
        target_p_fire = torch.linspace(0.25, 0.08, N)
        # Gaussian quantile: z such that P(V > V_th) = target_p_fire.
        z_scores = math.sqrt(2.0) * torch.erfinv(
            2.0 * (1.0 - target_p_fire) - 1.0
        )
        target_V_th = sigma_V_per_n * z_scores
        b_th_per_n = torch.clamp(target_V_th - self.v_th_min, min=0.05)
        # Use the per-n values as means, add N(0, 0.02) noise to break the
        # symmetry across the D channels
        self.b_th.data.copy_(b_th_per_n.repeat(D))
        self.b_th.data.add_(torch.empty_like(self.b_th).normal_(0, 0.02))

        # ====== 6. W_out firing-rate-balancing scaling ======
        out_scale_per_n = 1.0 / torch.sqrt(target_p_fire)
        out_scale_per_n = out_scale_per_n / out_scale_per_n.mean()
        out_scale_DN = out_scale_per_n.repeat(D)
        with torch.no_grad():
            self.W_out.weight.mul_(out_scale_DN.unsqueeze(0))

    def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
        """
        Parallel forward pass: process the whole sequence with a parallel scan.

        Args:
            spike_in_seq: (TK, batch, D) — input spikes for all T×K frames

        Returns:
            continuous_out: (TK, batch, D) — continuous outputs for all T×K
                frames (V_post projected through W_out)
        """
        TK, batch, D = spike_in_seq.shape
        DN = self.D * self.N

        # ====== Phase 1: batched projections (all TK frames at once) ======
        flat = spike_in_seq.reshape(TK * batch, D)

        I_all = F.linear(flat, self.W_in.weight).reshape(TK, batch, DN)
        raw_beta = F.linear(flat, self.W_beta_x.weight).reshape(TK, batch, DN)
        raw_alpha = F.linear(flat, self.W_alpha_x.weight).reshape(TK, batch, DN)
        raw_th = F.linear(flat, self.W_th_x.weight).reshape(TK, batch, DN)
        gate_all = torch.sigmoid(
            F.linear(flat, self.W_gate.weight).reshape(TK, batch, D)
        )
        I_skip_all = F.linear(flat, self.W_skip.weight).reshape(TK, batch, D)

        # ====== Phase 1b: fused activations (torch.compile → one kernel) ======
        beta_all, u_hidden, v_th_all = _fused_modulation(
            raw_beta, self.b_beta, raw_alpha, self.b_alpha,
            raw_th, self.b_th, self.v_th_min, I_all,
        )

        # Fetch the hidden neurons' initial state (float 0. before first use)
        v_init_hidden = self.hidden_neuron.v
        if isinstance(v_init_hidden, float):
            v_init_hidden = torch.zeros(batch, DN, device=flat.device, dtype=flat.dtype)

        s_hidden, V_post_hidden, _ = plif_parallel_forward(
            beta_all, u_hidden, v_th_all, v_init_hidden, max_iter=3,
            surrogate_function=self.hidden_neuron.surrogate_function,
        )

        # Update the hidden-neuron state (keep the last step for the next call)
        self.hidden_neuron.v = V_post_hidden[-1].detach()

        # ====== Phase 4: output projection (V_post → W_out: continuous gradient straight to β) ======
        # Feed V_post (the membrane voltage) rather than spikes into W_out to
        # remove the surrogate-gradient bottleneck:
        #   spike path:  ∂spike/∂β = surrogate'(V-v_th) · V_prev ≈ 0 (most steps)
        #   V_post path: ∂V_post/∂β = V_prev (no surrogate blocking; gradient every step)
        v_flat = V_post_hidden.reshape(TK * batch, DN)
        I_out_all = F.linear(v_flat, self.W_out.weight).reshape(TK, batch, D)
        I_total_all = I_out_all * gate_all + I_skip_all  # (TK, batch, D)

        # output_neuron was removed: continuous values are handled by the
        # layer-level K-frame aggregation
        return I_total_all  # (TK, batch, D), continuous values

    def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
        """
        Single-step forward pass (for debugging/compatibility).

        Args:
            spike_in: binary spike input, shape (batch, D), values in {0, 1}

        Returns:
            continuous_out: continuous output, shape (batch, D)
        """
        V_prev = self.hidden_neuron.v
        if isinstance(V_prev, float):
            V_prev = torch.zeros(
                spike_in.shape[0], self.D * self.N,
                device=spike_in.device, dtype=spike_in.dtype,
            )

        I_t = self.W_in(spike_in)

        # β/α/v_th modulation depends on spike_in only
        beta = torch.sigmoid(self.W_beta_x(spike_in) + self.b_beta)
        alpha = F.softplus(self.W_alpha_x(spike_in) + self.b_alpha)
        v_th = self.v_th_min + torch.abs(self.W_th_x(spike_in) + self.b_th)

        gate = torch.sigmoid(self.W_gate(spike_in))
        I_skip = self.W_skip(spike_in)

        s_hidden = self.hidden_neuron(I_t, beta, alpha, v_th)

        # Project V_post (membrane voltage), consistent with forward_parallel
        V_post = self.hidden_neuron.v  # membrane potential after fire + reset
        I_out = self.W_out(V_post)
        I_total = I_out * gate + I_skip

        return I_total  # continuous value
|
atomic_ops/snn_decoder_layer.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SNNDecoderLayer: 单个 SNN 解码层(Pre-LN 连续残差流 + 动态 K 帧聚合)
|
| 3 |
+
|
| 4 |
+
RMSNorm → PLIF → SNNBlock → 动态K聚合 → out_proj → 残差
|
| 5 |
+
RMSNorm → PLIF → SNNFFN → 动态K聚合 → out_proj → 残差
|
| 6 |
+
|
| 7 |
+
动态 K:
|
| 8 |
+
- K 是最大步数(K_max),不是固定步数。不同 token 有效步数 ∈ [1, K_max]。
|
| 9 |
+
- 每个 token 的 K 帧 SNN 输出,学习自适应停止概率 p_halt
|
| 10 |
+
- PonderNet 几何分布加权:λ_k = p_k · ∏_{j<k}(1-p_j),归一化后加权聚合
|
| 11 |
+
- 不同 token 有效步数不同:简单 token 早停(E[K]小),复杂 token 用满步数
|
| 12 |
+
- ponder_cost 正则化:鼓励用更少步数完成简单 token 的处理
|
| 13 |
+
|
| 14 |
+
数学推导:
|
| 15 |
+
停止概率: p_k = σ(halt_proj(frame_k)) ∈ (0,1)
|
| 16 |
+
生存概率: S_k = ∏_{j=1}^{k-1} (1 - p_j) — 到第 k 步还没停
|
| 17 |
+
权重: λ_k = p_k · S_k — 恰好在第 k 步停止的概率
|
| 18 |
+
归一化: λ̂_k = λ_k / Σ_k λ_k — 确保权重和为 1
|
| 19 |
+
聚合: output = Σ_k λ̂_k · frame_k
|
| 20 |
+
代价: E[K] = Σ_k k · λ̂_k — 期望步数
|
| 21 |
+
|
| 22 |
+
K_max 设计原则:
|
| 23 |
+
K_max 越大,模型对复杂 token 的处理能力越强(更多步数可用),
|
| 24 |
+
但计算量和显存线性增长。K_max=32 允许 token 使用 1~32 步。
|
| 25 |
+
PonderNet 的 ponder_cost 正则化确保简单 token 不浪费步数。
|
| 26 |
+
|
| 27 |
+
K 帧层间聚合:
|
| 28 |
+
- SNN 子层输出 K 帧连续值(V_post 经投影),PonderNet 加权聚合为 1 per token
|
| 29 |
+
- 聚合后经 out_proj 投影,广播回 K 帧做残差
|
| 30 |
+
- 使 β 的时间动力学通过 K 帧聚合梯度有效传播
|
| 31 |
+
|
| 32 |
+
对标 Qwen3DecoderLayer(Pre-LN 模式完全等价):
|
| 33 |
+
Qwen3: RMSNorm → Attention → residual → RMSNorm → MLP → residual
|
| 34 |
+
SNN: RMSNorm → PLIF → SNNBlock → 动态K聚合 → out_proj → residual
|
| 35 |
+
→ RMSNorm → PLIF → SNNFFN → 动态K聚合 → out_proj → residual
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import math
|
| 39 |
+
|
| 40 |
+
import torch
|
| 41 |
+
import torch.nn as nn
|
| 42 |
+
import torch.nn.functional as F
|
| 43 |
+
from spikingjelly.activation_based import base, surrogate
|
| 44 |
+
|
| 45 |
+
from .plif_node import PLIFNode
|
| 46 |
+
from .rms_norm import RMSNorm
|
| 47 |
+
from .snn_block import SNNBlock
|
| 48 |
+
from .snn_ffn import SNNFFN
|
| 49 |
+
from .parallel_scan import plif_rowparam_forward
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ====== Fused halt weight computation (torch.compile) ======
# Collapses 7-8 independent element-wise kernels into one fused kernel:
# sigmoid + clamp + log1p + cumsum + exp + normalize.
# The first call triggers JIT compilation (~seconds); later calls hit the cache.

@torch.compile(backend='inductor', fullgraph=True)
def _fused_geometric_halt(halt_logits):
    """Fused computation of the PonderNet geometric-distribution halting weights.

    Input:  halt_logits (seq_len, K, batch) — raw output of halt_proj
    Output: halt_weights (seq_len, K, batch) — normalized geometric weights, sum=1

    Math: p_k = σ(logit_k), S_k = ∏_{j<k}(1-p_j), λ_k = p_k·S_k, λ̂_k = λ_k/Σλ
    """
    # Clamp keeps probabilities strictly inside (0, 1) so log1p stays finite.
    p_halt = torch.sigmoid(halt_logits).clamp(min=1e-7, max=1.0 - 1e-7)
    log_1_minus_p = torch.log1p(-p_halt)  # (seq_len, K, batch)
    # Exclusive cumsum: log_survive[:, k, :] = Σ_{j<k} log(1-p_j)
    # Avoids torch.cat: fill [:, 1:] with cumsum of [:, :-1]
    log_survive = torch.zeros_like(log_1_minus_p)
    log_survive[:, 1:, :] = torch.cumsum(log_1_minus_p[:, :-1, :], dim=1)
    survive = torch.exp(log_survive)  # (seq_len, K, batch)
    halt_weights = p_halt * survive  # λ_k = p_k · S_k
    # Renormalize so the K weights of each token sum to 1 (ε guards division).
    halt_weights = halt_weights / (halt_weights.sum(dim=1, keepdim=True) + 1e-8)
    return halt_weights
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SNNDecoderLayer(base.MemoryModule):
    """
    Single SNN decoder layer (Pre-LN continuous residual stream + dynamic-K
    frame aggregation).

    A continuous-valued stream h of shape (TK, batch, D) flows between layers.
    Each sublayer runs:

        RMSNorm → PLIF → SNNBlock → dynamic-K aggregation → out_proj → residual
        RMSNorm → PLIF → SNNFFN  → dynamic-K aggregation → out_proj → residual

    Dynamic K (PonderNet-style):
      - K is a maximum step count (K_max), not a fixed one; each token's
        effective step count lies in [1, K_max].
      - For the K frames of SNN output per token, an adaptive halting
        probability p_halt is learned; geometric-distribution weights
        λ_k = p_k · ∏_{j<k}(1-p_j) are normalized and used for the weighted
        aggregation, so simple tokens halt early (small E[K]) while complex
        tokens use the full budget.
      - ponder_cost regularization encourages finishing simple tokens in
        fewer steps.

    Math:
        halting prob: p_k = σ(halt_proj(frame_k)) ∈ (0,1)
        survival:     S_k = ∏_{j=1}^{k-1} (1 - p_j)   — still running at step k
        weight:       λ_k = p_k · S_k                 — halts exactly at step k
        normalized:   λ̂_k = λ_k / Σ_k λ_k
        aggregation:  output = Σ_k λ̂_k · frame_k
        cost:         E[K] = Σ_k k · λ̂_k             — expected step count

    Aggregating the K frames per token gives β's temporal dynamics (which
    govern membrane evolution across the K steps) a differentiable token-level
    effect, so β's gradients are no longer pure noise.

    Structural analogy with a Pre-LN transformer decoder layer (e.g. Qwen3):
        Transformer: RMSNorm → Attention → residual → RMSNorm → MLP → residual
        SNN:         RMSNorm → PLIF → SNNBlock → K-agg → out_proj → residual
                     → RMSNorm → PLIF → SNNFFN → K-agg → out_proj → residual

    Args:
        D: visible (model) dimension
        N: state-expansion factor
        D_ff: FFN intermediate dimension
        v_th_min: lower bound of the SNNBlock dynamic threshold
        ffn_v_threshold: threshold of the SNNFFN gate/up neurons
        K: SNN time steps per token (maximum)
        num_layers: total layer count (scales residual output init and the
            SNNFFN down_proj)
        layer_idx: index of this layer
    """

    def __init__(
        self,
        D: int,
        N: int,
        D_ff: int,
        v_th_min: float,
        ffn_v_threshold: float,
        K: int = 16,
        num_layers: int = 1,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.D = D
        self.K = K

        self.snn_block = SNNBlock(
            D=D, N=N, v_th_min=v_th_min,
        )
        self.snn_ffn = SNNFFN(
            D=D, D_ff=D_ff,
            output_v_threshold=ffn_v_threshold,
            num_layers=num_layers,
            layer_idx=layer_idx,
        )

        # Pre-LN branch normalization: h → RMSNorm → PLIFNode
        self.block_norm = RMSNorm(D)
        self.ffn_norm = RMSNorm(D)

        # Input neurons: RMSNorm(h) → membrane-potential activation
        # (D-dimensional learnable β and V_th).
        self.input_neuron1 = PLIFNode(
            dim=D,
            init_tau=2.0,
            v_threshold=0.5,
            surrogate_function=surrogate.Sigmoid(alpha=4.0),
        )
        self.input_neuron2 = PLIFNode(
            dim=D,
            init_tau=2.0,
            v_threshold=0.5,
            surrogate_function=surrogate.Sigmoid(alpha=4.0),
        )

        # Output projections (synapses): sublayer output (D) → continuous space (D).
        self.block_out_proj = nn.Linear(D, D, bias=False)
        self.ffn_out_proj = nn.Linear(D, D, bias=False)

        # ====== Dynamic K: halting projections (synapse: SNN output → halt logit) ======
        # halt_proj: D → 1, one halting logit per step per token.
        # PonderNet geometric weighting replaces uniform-mean aggregation.
        self.block_halt = nn.Linear(D, 1, bias=True)
        self.ffn_halt = nn.Linear(D, 1, bias=True)

        # Residual-output scaled init (GPT-2 style: σ = 0.02 / √(2·num_layers)).
        std = 0.02 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.block_out_proj.weight, std=std)
        nn.init.normal_(self.ffn_out_proj.weight, std=std)

        # Halt init: small weights + negative bias → p_halt ≈ 0.03, i.e. close
        # to uniform aggregation at the start (σ(-3.5) ≈ 0.029; after geometric
        # normalization λ_1/λ_K ≈ 1.5, near uniform).
        for halt in [self.block_halt, self.ffn_halt]:
            nn.init.xavier_uniform_(halt.weight)
            halt.weight.data.mul_(0.01)
            nn.init.constant_(halt.bias, -3.5)

    def _input_neuron_parallel(self, input_neuron, x):
        """
        Parallel-scan forward pass of an input PLIF neuron.

        Full PLIF dynamics: V[t] = β·V[t-1] + (1-β)·x[t], spike = Θ(V-V_th),
        soft reset. The returned activation is the membrane leak (1-β)·V_post
        — the amount lost per step to exponential decay. Compared to passing
        V_post directly, the leak naturally emphasizes fast-responding neurons
        (large 1-β) and suppresses slow-memory neurons (small 1-β), giving an
        implicit time-scale weighting.

        Args:
            input_neuron: PLIFNode instance (D-dim learnable β and V_th)
            x: (TK, batch, D) — continuous-valued input

        Returns:
            leak: (TK, batch, D) — membrane leak (1-β)·V_post
        """
        TK, batch, D = x.shape

        beta = input_neuron.beta  # (D,)
        u = (1.0 - beta) * x  # (D,) broadcast → (TK, batch, D)

        # spikingjelly stores v as a float (0.0) before the first step;
        # materialize it as a zero tensor for the scan kernel.
        v_init = input_neuron.v
        if isinstance(v_init, float):
            v_init = torch.zeros(batch, D, device=x.device, dtype=x.dtype)

        beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
        v_th_row = input_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()

        spike, V_post = plif_rowparam_forward(
            beta_row, u, v_th_row, v_init,
            surrogate_function=input_neuron.surrogate_function,
        )

        # Carry the final membrane state across calls, detached from the graph.
        input_neuron.v = V_post[-1].detach()
        return (1.0 - beta) * V_post  # membrane leak

    def _adaptive_aggregate(self, frames, halt_proj):
        """
        PonderNet-style adaptive K-frame aggregation (dynamic-K core,
        torch.compile-fused).

        Computes a halting probability per step and aggregates with
        geometric-distribution weights, so different tokens get different
        effective step counts.

        Optimization: _fused_geometric_halt fuses
        sigmoid+log1p+cumsum+exp+normalize into a single inductor kernel.

        Math:
            p_k = σ(halt_proj(frame_k))   — halting probability
            S_k = ∏_{j<k} (1-p_j)         — survival probability
            λ_k = p_k · S_k               — geometric weight
            λ̂_k = λ_k / Σ λ_k            — normalized
            output = Σ λ̂_k · frame_k     — weighted aggregation
            E[K] = Σ k · λ̂_k             — expected step count (ponder cost)

        Args:
            frames: (seq_len, K, batch, D) — K frames of sublayer output
            halt_proj: nn.Linear(D, 1) — halting projection (synapse)

        Returns:
            aggregated: (seq_len, batch, D) — weighted aggregation
            ponder_cost: scalar — mean expected step count (for regularization)
            expected_k: (seq_len, batch), detached — per-token E[K] (diagnostics)
        """
        seq_len, K, batch, D = frames.shape

        # ====== 1. halt_proj matmul (cuBLAS) + fused geometric weights (inductor) ======
        halt_logits = halt_proj(frames).squeeze(-1)  # (seq_len, K, batch)
        halt_weights = _fused_geometric_halt(halt_logits)  # (seq_len, K, batch), normalized

        # ====== 2. Weighted aggregation ======
        # (seq_len, K, batch, 1) × (seq_len, K, batch, D) → sum → (seq_len, batch, D)
        aggregated = (frames * halt_weights.unsqueeze(-1)).sum(dim=1)

        # ====== 3. Ponder cost: E[K] per token ======
        steps = torch.arange(1, K + 1, device=frames.device, dtype=frames.dtype)
        expected_k = (halt_weights * steps[None, :, None]).sum(dim=1)  # (seq_len, batch)
        ponder_cost = expected_k.mean()  # scalar

        return aggregated, ponder_cost, expected_k.detach()

    def forward_parallel(self, h):
        """
        Parallel forward pass: continuous residual stream + dynamic-K
        aggregation.

        The SNN sublayers operate along the TK dimension (K-step temporal
        dynamics); their output is adaptively aggregated over the K frames
        (PonderNet), projected by out_proj, and broadcast back to TK for the
        residual connection.

        Args:
            h: (TK, batch, D) — continuous-valued input; TK must be a
               multiple of K (seq_len = TK // K)

        Returns:
            h: (TK, batch, D) — continuous-valued output
            ponder_cost: scalar — mean expected steps over both sublayers
        """
        TK, batch, D = h.shape
        K = self.K
        seq_len = TK // K

        # Sublayer 1: SNNBlock — RMSNorm → PLIF → SNNBlock → K-agg → out_proj → residual
        v_in = self._input_neuron_parallel(self.input_neuron1, self.block_norm(h))
        cont_block = self.snn_block.forward_parallel(v_in)  # (TK, batch, D), continuous

        # Dynamic K-frame aggregation (PonderNet):
        # (TK, batch, D) → (seq_len, K, batch, D) → weighted → (seq_len, batch, D)
        frames_block = cont_block.view(seq_len, K, batch, D)
        combined_block, pc_block, ek_block = self._adaptive_aggregate(frames_block, self.block_halt)
        res_block = self.block_out_proj(combined_block)  # (seq_len, batch, D)
        res_block = res_block - res_block.mean(dim=-1, keepdim=True)  # center residual

        # Broadcast back to TK: each token's residual is repeated K times.
        h = h + res_block.repeat_interleave(K, dim=0)

        # Sublayer 2: SNNFFN — RMSNorm → PLIF → SNNFFN → K-agg → out_proj → residual
        v_in2 = self._input_neuron_parallel(self.input_neuron2, self.ffn_norm(h))
        cont_ffn = self.snn_ffn.forward_parallel(v_in2)  # (TK, batch, D), continuous

        frames_ffn = cont_ffn.view(seq_len, K, batch, D)
        combined_ffn, pc_ffn, ek_ffn = self._adaptive_aggregate(frames_ffn, self.ffn_halt)
        res_ffn = self.ffn_out_proj(combined_ffn)
        res_ffn = res_ffn - res_ffn.mean(dim=-1, keepdim=True)

        h = h + res_ffn.repeat_interleave(K, dim=0)

        ponder_cost = (pc_block + pc_ffn) / 2.0  # average over both sublayers

        # Record the per-token E[K] range (diagnostics only; off the graph).
        # ek_block/ek_ffn: (seq_len, batch), detached
        with torch.no_grad():
            all_ek = torch.cat([ek_block.flatten(), ek_ffn.flatten()])
            self._ek_min = all_ek.min().item()
            self._ek_max = all_ek.max().item()

        return h, ponder_cost

    def single_step_forward(self, h):
        """
        Single-step forward pass: continuous residual stream.

        Note: single-step mode cannot do dynamic-K aggregation (each step is
        processed independently). Training and inference both use
        forward_parallel (with dynamic-K aggregation); this method is for
        debugging only.

        Args:
            h: (batch, D) — continuous-valued input

        Returns:
            h: (batch, D) — continuous-valued output
            ponder_cost: scalar — 0.0 (no ponder cost in single-step mode)
        """
        # Sublayer 1: SNNBlock — RMSNorm → PLIF(leak) → SNNBlock → out_proj → residual
        _ = self.input_neuron1(self.block_norm(h))  # run PLIF dynamics, update .v
        v_in = (1.0 - self.input_neuron1.beta) * self.input_neuron1.v  # membrane leak
        cont_block = self.snn_block.single_step_forward(v_in)
        res_block = self.block_out_proj(cont_block)
        h = h + res_block - res_block.mean(dim=-1, keepdim=True)

        # Sublayer 2: SNNFFN — RMSNorm → PLIF(leak) → SNNFFN → out_proj → residual
        _ = self.input_neuron2(self.ffn_norm(h))
        v_in2 = (1.0 - self.input_neuron2.beta) * self.input_neuron2.v  # membrane leak
        cont_ffn = self.snn_ffn.single_step_forward(v_in2)
        res_ffn = self.ffn_out_proj(cont_ffn)
        h = h + res_ffn - res_ffn.mean(dim=-1, keepdim=True)

        return h, torch.tensor(0.0, device=h.device)
|
atomic_ops/snn_ffn.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SNNFFN: SNN 等价的 Feed-Forward Network
|
| 3 |
+
|
| 4 |
+
对标 Qwen3MLP 的 SwiGLU 结构:
|
| 5 |
+
Qwen3 MLP: down_proj( SiLU(gate_proj(x)) * up_proj(x) )
|
| 6 |
+
SNN FFN: down_proj( gate_V_post * up_V_post ) + skip
|
| 7 |
+
|
| 8 |
+
膜电位门控(对标 SiLU gating):
|
| 9 |
+
gate/up 神经元完整 PLIF 动力学(积分+阈值+重置),
|
| 10 |
+
输出膜电位 V_post 做连续乘法门控,替代 binary AND 门。
|
| 11 |
+
|
| 12 |
+
信号流:
|
| 13 |
+
x → gate_proj → gate_neuron → V_post_gate
|
| 14 |
+
x → up_proj → up_neuron → V_post_up
|
| 15 |
+
V_post_gate × V_post_up → gated
|
| 16 |
+
down_proj(gated) + skip_proj(x) → 连续输出
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from spikingjelly.activation_based import base, layer, surrogate
|
| 24 |
+
|
| 25 |
+
from .plif_node import PLIFNode
|
| 26 |
+
from .parallel_scan import plif_rowparam_forward
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SNNFFN(base.MemoryModule):
    """
    SNN equivalent of a Feed-Forward Network.

    Mirrors the SwiGLU structure of a transformer MLP (e.g. Qwen3):
        MLP:     down_proj( SiLU(gate_proj(x)) * up_proj(x) )
        SNN FFN: down_proj( gate_leak * up_leak ) + skip

    Membrane-potential gating (analogous to SiLU gating): the gate/up neurons
    run full PLIF dynamics (integrate + threshold + reset); their membrane
    leak is multiplied as a continuous gate, replacing a binary AND gate.

    Signal flow:
        x → gate_proj → gate_neuron → leak_gate
        x → up_proj   → up_neuron   → leak_up
        leak_gate × leak_up → gated
        down_proj(gated) + skip_proj(x) → continuous output

    Args:
        D: visible dimension (input/output spike dimension)
        D_ff: intermediate dimension (transformer intermediate_size analogue)
        output_v_threshold: threshold of the gate/up neurons
        num_layers: total layer count, used to scale down_proj
        layer_idx: index of this layer
        surrogate_function: surrogate-gradient function
            NOTE(review): the default is evaluated once at class-definition
            time, so all SNNFFN instances using the default share one
            Sigmoid object — confirm the surrogate is stateless.
    """

    def __init__(
        self,
        D: int,
        D_ff: int,
        output_v_threshold: float = 0.3,
        num_layers: int = 1,
        layer_idx: int = 0,
        surrogate_function=surrogate.Sigmoid(alpha=4.0),
    ):
        super().__init__()
        self.D = D
        self.D_ff = D_ff

        # ====== Three projection paths (SwiGLU analogue: gate, up, down) ======
        self.gate_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
        self.up_proj = layer.Linear(D, D_ff, bias=False, step_mode='s')
        self.down_proj = layer.Linear(D_ff, D, bias=False, step_mode='s')

        # ====== Residual path ======
        self.skip_proj = layer.Linear(D, D, bias=False, step_mode='s')

        # ====== Neurons (D_ff-dimensional learnable β and V_th) ======
        # gate_neuron: gating firing
        self.gate_neuron = PLIFNode(
            dim=D_ff,
            init_tau=2.0,
            v_threshold=output_v_threshold,
            surrogate_function=surrogate_function,
        )
        # up_neuron: value firing
        self.up_neuron = PLIFNode(
            dim=D_ff,
            init_tau=2.0,
            v_threshold=output_v_threshold,
            surrogate_function=surrogate_function,
        )
        # ====== Parameter initialization ======
        self._initialize_parameters(num_layers)

    def _initialize_parameters(self, num_layers: int):
        """Initialize projection weights.

        - gate_proj, up_proj, skip_proj: Kaiming uniform
        - down_proj: Kaiming uniform × 1/√(num_layers), to limit gradient
          growth in deep stacks
        """
        for lin in [self.gate_proj, self.up_proj, self.skip_proj]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        self.down_proj.weight.data.mul_(1.0 / math.sqrt(num_layers))

    def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor:
        """
        Parallel forward pass: processes the whole sequence with a parallel
        scan.

        Optimizations:
        - gate_proj + up_proj merged into a single matmul (2 launches → 1)
        - gate + up PLIF scan via the row-param kernel (β/V_th read from
          registers, no memory-bandwidth cost for expanded parameters)
        - u_merged built by one broadcast multiply instead of two scales + cat

        Args:
            spike_in_seq: (TK, batch, D) — input over all T×K frames

        Returns:
            continuous_out: (TK, batch, D) — continuous output over all frames
        """
        TK, batch, D = spike_in_seq.shape
        D_ff = self.D_ff
        flat = spike_in_seq.reshape(TK * batch, D)

        # ====== Phase 1: batched projections (gate+up as one matmul) ======
        W_gate_up = torch.cat([self.gate_proj.weight, self.up_proj.weight], dim=0)
        I_gate_up = F.linear(flat, W_gate_up).reshape(TK, batch, 2 * D_ff)
        I_skip = F.linear(flat, self.skip_proj.weight).reshape(TK, batch, D)

        # ====== Phase 2: merged gate+up PLIF scan (row-param kernel) ======
        beta_gate = self.gate_neuron.beta  # (D_ff,)
        beta_up = self.up_neuron.beta  # (D_ff,)
        surr = self.gate_neuron.surrogate_function

        # u_merged: vector scaling (D_ff-dim β concatenated, no expand needed)
        scale_row = torch.cat([1.0 - beta_gate, 1.0 - beta_up])  # (2*D_ff,)
        u_merged = I_gate_up * scale_row  # (TK, batch, 2*D_ff), broadcast

        # beta_row / v_th_row: (batch, 2*D_ff) — learnable per-channel params
        beta_row = torch.cat([beta_gate, beta_up])  # (2*D_ff,)
        beta_row = beta_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()

        v_th_row = torch.cat([self.gate_neuron.v_th, self.up_neuron.v_th])  # (2*D_ff,)
        v_th_row = v_th_row.unsqueeze(0).expand(batch, 2 * D_ff).contiguous()

        # v_init_merged: (batch, 2*D_ff); spikingjelly stores v as a float
        # (0.0) before the first step — materialize zeros in that case.
        v_init_gate = self.gate_neuron.v
        if isinstance(v_init_gate, float):
            v_init_gate = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
        v_init_up = self.up_neuron.v
        if isinstance(v_init_up, float):
            v_init_up = torch.zeros(batch, D_ff, device=flat.device, dtype=flat.dtype)
        v_init_merged = torch.cat([v_init_gate, v_init_up], dim=-1)

        # Row-param PLIF scan: β/V_th are read from registers, not DRAM.
        spike_merged, V_post_merged = plif_rowparam_forward(
            beta_row, u_merged, v_th_row, v_init_merged,
            surrogate_function=surr,
        )

        # Membrane leak as the activation: leak = (1-β) · V_post
        gate_leak = V_post_merged[:, :, :D_ff] * (1.0 - beta_gate)  # (TK, batch, D_ff)
        up_leak = V_post_merged[:, :, D_ff:] * (1.0 - beta_up)  # (TK, batch, D_ff)
        # Carry final membrane states across calls, detached from the graph.
        self.gate_neuron.v = V_post_merged[-1, :, :D_ff].detach()
        self.up_neuron.v = V_post_merged[-1, :, D_ff:].detach()

        # ====== Phase 3: continuous gating (leak × leak, SwiGLU analogue) + down-projection ======
        gated = gate_leak * up_leak  # (TK, batch, D_ff)
        gated_flat = gated.reshape(TK * batch, D_ff)
        I_out = F.linear(gated_flat, self.down_proj.weight).reshape(TK, batch, D) + I_skip

        # output_neuron removed: continuous values are handled by the layer's
        # K-frame aggregation.
        return I_out  # (TK, batch, D), continuous

    def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
        """
        Single-step forward pass.

        Args:
            spike_in: binary spike input, shape (batch, D), values in {0, 1}

        Returns:
            continuous_out: continuous output, shape (batch, D)
        """
        # Gating path — membrane-leak activation
        _ = self.gate_neuron(self.gate_proj(spike_in))
        gate_leak = (1.0 - self.gate_neuron.beta) * self.gate_neuron.v  # leak

        # Value path — membrane-leak activation
        _ = self.up_neuron(self.up_proj(spike_in))
        up_leak = (1.0 - self.up_neuron.beta) * self.up_neuron.v  # leak

        # Continuous gating (SwiGLU analogue)
        gated = gate_leak * up_leak

        # Down-projection + residual
        I_out = self.down_proj(gated) + self.skip_proj(spike_in)  # R^D
        return I_out  # continuous
|
config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "neuronspark",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"NeuronSparkForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_neuronspark.NeuronSparkConfig",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_neuronspark.NeuronSparkForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"vocab_size": 6144,
|
| 11 |
+
"D": 896,
|
| 12 |
+
"N": 8,
|
| 13 |
+
"K": 16,
|
| 14 |
+
"num_layers": 20,
|
| 15 |
+
"D_ff": 2688,
|
| 16 |
+
"v_th_min": 0.1,
|
| 17 |
+
"torch_dtype": "float32",
|
| 18 |
+
"transformers_version": "4.52.0",
|
| 19 |
+
"_training_step": 85000,
|
| 20 |
+
"_tokens_seen": 203910528
|
| 21 |
+
}
|
configuration.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"framework": "pytorch", "task": "other"}
|
configuration_neuronspark.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""NeuronSpark 模型配置。"""
|
| 2 |
+
|
| 3 |
+
from transformers import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class NeuronSparkConfig(PretrainedConfig):
|
| 7 |
+
"""
|
| 8 |
+
SNN 隐状态空间语言模型配置。
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
vocab_size: 词表大小
|
| 12 |
+
D: 隐层维度
|
| 13 |
+
N: 状态扩展因子(每通道隐神经元数)
|
| 14 |
+
K: 每 token 最大 SNN 时间步(PonderNet 动态决定有效步数)
|
| 15 |
+
num_layers: SNN 解码层数
|
| 16 |
+
D_ff: FFN 中间层维度
|
| 17 |
+
v_th_min: 动态阈值下限
|
| 18 |
+
"""
|
| 19 |
+
model_type = "neuronspark"
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
vocab_size=6144,
|
| 24 |
+
D=896,
|
| 25 |
+
N=8,
|
| 26 |
+
K=16,
|
| 27 |
+
num_layers=20,
|
| 28 |
+
D_ff=2688,
|
| 29 |
+
v_th_min=0.1,
|
| 30 |
+
**kwargs,
|
| 31 |
+
):
|
| 32 |
+
self.D = D
|
| 33 |
+
self.N = N
|
| 34 |
+
self.K = K
|
| 35 |
+
self.num_layers = num_layers
|
| 36 |
+
self.D_ff = D_ff
|
| 37 |
+
self.v_th_min = v_th_min
|
| 38 |
+
super().__init__(vocab_size=vocab_size, **kwargs)
|
model.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SNNLanguageModel: SNN 隐状态空间语言模型(全膜电位 + 动态 K)
|
| 3 |
+
|
| 4 |
+
架构(三段式):
|
| 5 |
+
model.encode(token_ids) → h_seq # 输入: embed → repeat K 次(可微分)
|
| 6 |
+
model.snn_forward(h_seq) → h_out, pc # SNN 核心: 20 层,全膜电位 + 动态 K 聚合
|
| 7 |
+
model.decode(h_out, seq) → logits # 输出: output_neuron(V_post) → K帧mean → proj → logits
|
| 8 |
+
|
| 9 |
+
核心设计:
|
| 10 |
+
1. 膜电位泄漏量:PLIFNode 输出 (1-β)·V_post(泄漏量),自然强调快响应神经元
|
| 11 |
+
2. 动态 K:PonderNet 自适应停止,不同 token 不同有效步数
|
| 12 |
+
- 每层每子层学习 halt_proj(D→1),从 SNN 输出逐步计算停止概率
|
| 13 |
+
- 几何分布权重加权聚合,替代 uniform mean
|
| 14 |
+
- ponder_cost 正则化鼓励早停
|
| 15 |
+
|
| 16 |
+
数学原理见 SNN_SELECTIVE_STATE_SPACE.md。
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
from spikingjelly.activation_based import functional, surrogate
|
| 27 |
+
from torch.utils.checkpoint import checkpoint
|
| 28 |
+
|
| 29 |
+
from atomic_ops import SNNDecoderLayer
|
| 30 |
+
from atomic_ops.plif_node import PLIFNode
|
| 31 |
+
from atomic_ops.rms_norm import RMSNorm
|
| 32 |
+
from atomic_ops.parallel_scan import plif_rowparam_forward
|
| 33 |
+
# fp16_encode/fp16_decode 已移除: 全膜电位架构不需要 spike 编解码
|
| 34 |
+
from atomic_ops.lateral_inhibition import LateralInhibition
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
class SNNModelOutput:
    """Model output container, mirroring the CausalLMOutputWithPast interface."""
    last_loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    ponder_cost: Optional[torch.Tensor] = None  # dynamic K: mean expected step count
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SNNLanguageModel(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
从零训练的 SNN 隐状态空间语言模型(parallel scan)。
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
vocab_size: 词表大小(默认 6144,自训练 BPE)
|
| 51 |
+
D: 可见维度
|
| 52 |
+
N: 状态扩展因子
|
| 53 |
+
K: 每 token 最大 SNN 时间步数(K_max)。PonderNet 动态决定有效步数 ∈ [1, K]。
|
| 54 |
+
K 越大 → 复杂 token 可用更多步数,但计算量和显存线性增长。
|
| 55 |
+
num_layers: SNN 解码层数
|
| 56 |
+
D_ff: FFN 中间层维度
|
| 57 |
+
v_th_min: 动态阈值下限
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
    def __init__(
        self,
        vocab_size: int = 6144,
        D: int = 1024,
        N: int = 8,
        K: int = 32,
        num_layers: int = 20,
        D_ff: int = 3072,
        v_th_min: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.D = D
        self.N = N
        self.K = K
        self.num_layers = num_layers
        self.D_ff = D_ff

        # ====== Embedding + norm (all trainable) ======
        self.embed_tokens = nn.Embedding(vocab_size, D)
        self.norm = LateralInhibition(D)

        # ====== Decode projection ======
        self.decode_proj = nn.Linear(D, D)

        # ====== Output RMSNorm + output neuron ======
        self.output_norm = RMSNorm(D)
        self.output_neuron = PLIFNode(
            dim=D,
            init_tau=2.0,
            v_threshold=0.3,
            surrogate_function=surrogate.Sigmoid(alpha=4.0),
        )

        # ====== SNN decoder layers ======
        self.layers = nn.ModuleList([
            SNNDecoderLayer(
                D=D, N=N, D_ff=D_ff, v_th_min=v_th_min,
                ffn_v_threshold=0.15,
                K=K,
                num_layers=num_layers,
                layer_idx=i,
            )
            for i in range(num_layers)
        ])

        self._init_weights()
|
| 107 |
+
|
| 108 |
+
    def _init_weights(self):
        """Initialize all trainable weights (training from scratch).

        Embedding gets N(0, 0.02); decode_proj gets Xavier-uniform weights
        and a zero bias.
        """
        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
        nn.init.xavier_uniform_(self.decode_proj.weight)
        nn.init.zeros_(self.decode_proj.bias)
|
| 113 |
+
|
| 114 |
+
def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
"""输入边界:token_ids → 连续值序列。
|
| 116 |
+
|
| 117 |
+
Embedding lookup,每 token 重复 K 次作为 SNN 时间步输入。
|
| 118 |
+
梯度可通过 embedding 直接反传。
|
| 119 |
+
|
| 120 |
+
Returns: (seq_len*K, batch, D), 连续值
|
| 121 |
+
"""
|
| 122 |
+
emb = self.embed_tokens(token_ids) # (batch, seq_len, D)
|
| 123 |
+
batch, seq_len, D = emb.shape
|
| 124 |
+
# 每 token 重复 K 次: (batch, seq_len, D) → (batch, seq_len*K, D) → (TK, batch, D)
|
| 125 |
+
emb_k = emb.unsqueeze(2).expand(-1, -1, self.K, -1).reshape(batch, seq_len * self.K, D)
|
| 126 |
+
return emb_k.permute(1, 0, 2).contiguous() # (TK, batch, D)
|
| 127 |
+
|
| 128 |
+
    def snn_forward(self, spike_seq: torch.Tensor):
        """SNN core: spike_seq → (h_out, ponder_cost).

        Pure SNN-layer computation with gradient checkpointing. Each layer
        returns (h, ponder_cost); the ponder_cost is returned from the
        checkpointed function so its gradient graph is preserved.

        Note: functional.reset_net runs inside the checkpointed function, so
        neuron states are also reset on recomputation during backward,
        keeping the recomputed forward consistent with the original.

        Args:
            spike_seq: (seq_len*K, batch, D) continuous-valued input

        Returns:
            h: (seq_len*K, batch, D), continuous values
            total_ponder_cost: scalar, expected step count averaged over layers
        """
        h = spike_seq
        ponder_costs = []

        def _layer_forward(layer_mod, x):
            functional.reset_net(layer_mod)
            return layer_mod.forward_parallel(x)  # returns (h, ponder_cost)

        for layer_module in self.layers:
            h, pc = checkpoint(
                _layer_forward, layer_module, h,
                use_reentrant=False,
            )
            ponder_costs.append(pc)

        # Average over layers (assumes num_layers >= 1).
        total_ponder_cost = sum(ponder_costs) / len(ponder_costs)
        return h, total_ponder_cost
|
| 154 |
+
|
| 155 |
+
    def _output_neuron_parallel(self, h: torch.Tensor) -> torch.Tensor:
        """Parallel-scan forward of the output PLIF neuron: continuous h →
        membrane leak.

        Args:
            h: (TK, batch, D) continuous values (output of the last SNN layer)

        Returns:
            leak: (TK, batch, D) membrane leak (1-β)·V_post
        """
        TK, batch, D = h.shape

        beta = self.output_neuron.beta  # (D,)
        u = (1.0 - beta) * h  # PLIF drive: u = (1-β) · x

        # spikingjelly stores v as a float (0.0) before the first step;
        # materialize it as a zero tensor for the scan kernel.
        v_init = self.output_neuron.v
        if isinstance(v_init, float):
            v_init = torch.zeros(batch, D, device=h.device, dtype=h.dtype)

        beta_row = beta.unsqueeze(0).expand(batch, D).contiguous()
        v_th_row = self.output_neuron.v_th.unsqueeze(0).expand(batch, D).contiguous()

        spike, V_post = plif_rowparam_forward(
            beta_row, u, v_th_row, v_init,
            surrogate_function=self.output_neuron.surrogate_function,
        )

        # Carry the final membrane state across calls, detached from the graph.
        self.output_neuron.v = V_post[-1].detach()
        return (1.0 - beta) * V_post  # membrane leak
|
| 183 |
+
|
| 184 |
+
def decode(self, h_out: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Output boundary: continuous h -> output neuron (V_post) -> K-frame pooling -> logits.

    Gradient path: loss -> logits -> norm -> decode_proj -> K-frame mean
                   -> V_post (output_neuron) -> h_out -> SNN layers.

    Returns:
        (batch, seq_len, vocab_size) logits via the tied embedding head.
    """
    normed = self.output_norm(h_out)                 # RMSNorm: keeps scale in check
    leak = self._output_neuron_parallel(normed)      # (TK, batch, D) membrane output
    # Pool the K frames of each token:
    # (TK, batch, D) -> (seq_len, K, batch, D) -> mean over K -> (seq_len, batch, D)
    pooled = leak.view(seq_len, self.K, -1, self.D).mean(dim=1)
    pooled = pooled.permute(1, 0, 2)                 # (batch, seq_len, D)
    projected = self.decode_proj(pooled)             # (batch, seq_len, D)
    projected = self.norm(projected)                 # (batch, seq_len, D)
    # Tied head: reuse embed_tokens.weight as the output projection.
    return F.linear(projected, self.embed_tokens.weight)
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
def generate(
    self,
    prompt_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: int = 50,
    eos_token_id: Optional[int] = None,
) -> torch.Tensor:
    """Autoregressive generation with neuron state carried across tokens.

    1. Prefill: the whole prompt goes through forward_parallel at once,
       building up every neuron's membrane state V.
    2. Decode loop: one token at a time; each token's K frames also go
       through forward_parallel (same Triton parallel-scan kernel), and
       the membrane states flow continuously from token to token.

    Args:
        prompt_ids: (batch, prompt_len) token IDs.
        max_new_tokens: maximum number of tokens to generate.
        temperature: sampling temperature (<= 0 means greedy).
        top_k: top-k sampling cutoff (None/0 disables it).
        eos_token_id: stop once every sequence has emitted this token.

    Returns:
        (batch, prompt_len + generated_len) full sequence.
    """
    prompt_len = prompt_ids.shape[1]

    # Fresh sequence: every neuron starts from V = 0.
    for block_layer in self.layers:
        functional.reset_net(block_layer)
    functional.reset_net(self.output_neuron)

    # ====== Prefill: process the entire prompt in parallel ======
    hidden = self.encode(prompt_ids)  # (prompt_len*K, batch, D), continuous
    for block_layer in self.layers:
        hidden, _ = block_layer.forward_parallel(hidden)  # ponder cost unused at inference
    # Every neuron's .v now holds the state at the end of the prompt.

    logits = self.decode(hidden, prompt_len)

    # Sample the first new token.
    token = self._sample(logits[:, -1, :], temperature, top_k)
    new_tokens = [token]

    # ====== Autoregressive: token by token, K frames per token ======
    for _ in range(max_new_tokens - 1):
        if eos_token_id is not None and (token == eos_token_id).all():
            break

        # Encode the single token into K continuous frames (reuses encode).
        hidden = self.encode(token)  # (K, batch, D)

        # No reset here — membrane state .v carries over between tokens.
        for block_layer in self.layers:
            hidden, _ = block_layer.forward_parallel(hidden)

        logits = self.decode(hidden, 1)

        token = self._sample(logits[:, -1, :], temperature, top_k)
        new_tokens.append(token)

    return torch.cat([prompt_ids, torch.cat(new_tokens, dim=1)], dim=1)
|
| 266 |
+
|
| 267 |
+
def _sample(self, logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
|
| 268 |
+
"""从 logits 采样(temperature + top-k)。
|
| 269 |
+
|
| 270 |
+
Returns: (batch, 1)
|
| 271 |
+
"""
|
| 272 |
+
if temperature <= 0:
|
| 273 |
+
return logits.argmax(dim=-1, keepdim=True)
|
| 274 |
+
logits = logits / temperature
|
| 275 |
+
if top_k is not None and top_k > 0:
|
| 276 |
+
top_k = min(top_k, logits.size(-1))
|
| 277 |
+
v, _ = torch.topk(logits, top_k)
|
| 278 |
+
logits[logits < v[:, [-1]]] = float('-inf')
|
| 279 |
+
probs = F.softmax(logits, dim=-1)
|
| 280 |
+
return torch.multinomial(probs, num_samples=1)
|
| 281 |
+
|
| 282 |
+
def forward(
    self,
    token_ids: torch.Tensor,
    target_ids: torch.Tensor = None,
) -> SNNModelOutput:
    """
    Forward pass (full membrane potential + dynamic K).

    encode      -> h_seq        # input (embedding repeated K times, differentiable)
    snn_forward -> h_out, pc    # SNN core (membrane potentials + dynamic-K pooling)
    decode      -> logits       # output (V_post -> K-frame mean -> proj -> logits)

    Gradient path:
        embed_tokens -> repeat K -> SNN layers (V_post + dynamic K)
        -> output_neuron (V_post) -> K-frame mean -> decode_proj -> logits (tied head)
    ponder_cost: dynamic-K regularizer, rewards fewer steps on easy tokens.
    """
    seq_len = token_ids.shape[1]

    # Reset every neuron's state for the new batch.
    for block_layer in self.layers:
        functional.reset_net(block_layer)
    functional.reset_net(self.output_neuron)

    # Three-stage pipeline.
    spikes = self.encode(token_ids)                # input boundary
    h_out, ponder_cost = self.snn_forward(spikes)  # SNN core + ponder cost
    logits = self.decode(h_out, seq_len)           # output boundary

    if target_ids is None:
        return SNNModelOutput(logits=logits, ponder_cost=ponder_cost)

    # Per-token CE; index 0 is padding and is ignored. Left un-reduced
    # so callers can apply their own mask before averaging.
    self.last_loss = F.cross_entropy(
        logits.reshape(-1, self.vocab_size),
        target_ids.reshape(-1),
        ignore_index=0, reduction='none',
    )
    return SNNModelOutput(
        last_loss=self.last_loss,
        ponder_cost=ponder_cost,
    )
|
| 324 |
+
|
| 325 |
+
def compensate_modulation_gradients(self, max_comp: float = 100.0):
    """
    Natural-gradient compensation for modulation parameters (two phases).

    Phase 1 — sigmoid/softplus saturation:
        beta = sigmoid(b_beta); in the high-beta regime (beta=0.99,
        sigmoid'=0.01) gradients shrink ~100x. Dividing the gradient by
        activation'(b) is equivalent to descending directly in beta/alpha
        space.

    Phase 2 — per-layer gradient balancing:
        the residual chain amplifies backprop ~1.17x per layer (~20x
        L0/L19 ratio over 20 layers), suppressing deep-layer selectivity
        parameters (b_beta/b_alpha/b_th). Fix: normalize each layer's
        modulation-gradient norm to the geometric mean across layers.

    Call after scaler.unscale_(optimizer) and before clip_grad_norm_.

    Args:
        max_comp: cap on any compensation factor (guards against blow-ups).
    """
    lo = 1.0 / max_comp

    # ====== Phase 1: sigmoid/softplus saturation compensation ======
    for block_layer in self.layers:
        snn_block = block_layer.snn_block

        # b_beta: sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)) = beta * (1-beta)
        if snn_block.b_beta.grad is not None:
            with torch.no_grad():
                b = torch.sigmoid(snn_block.b_beta.data)
                snn_block.b_beta.grad.div_((b * (1.0 - b)).clamp(min=lo))

        # b_alpha: softplus'(z) = sigmoid(z) — milder decay, floored at 0.1.
        if snn_block.b_alpha.grad is not None:
            with torch.no_grad():
                snn_block.b_alpha.grad.div_(
                    torch.sigmoid(snn_block.b_alpha.data).clamp(min=0.1)
                )

        # b_th: |.| has derivative +-1 — no decay, nothing to compensate.

    # ====== Phase 2: cross-layer gradient balancing ======
    # Residual backprop path dh_{l+1}/dh_l = I + dsublayer/dh_l amplifies
    # each layer ~1.17x; over 20 layers L0's gradient dwarfs L19's.
    # Rescale each layer's modulation-gradient norm to the geometric mean.
    with torch.no_grad():
        for name in ('b_beta', 'b_alpha', 'b_th'):
            entries = []  # (param, grad_norm) pairs with non-negligible grads
            for block_layer in self.layers:
                param = getattr(block_layer.snn_block, name)
                if param.grad is None:
                    continue
                grad_norm = param.grad.norm().item()
                if grad_norm > 1e-12:
                    entries.append((param, grad_norm))

            if len(entries) < 2:
                continue
            # Geometric mean: exp(mean(log(norms))) — log-scale balancing,
            # robust to extreme values.
            target = math.exp(
                sum(math.log(gn) for _, gn in entries) / len(entries)
            )
            for param, grad_norm in entries:
                factor = min(max(target / grad_norm, lo), max_comp)
                param.grad.mul_(factor)
|
| 387 |
+
|
| 388 |
+
def get_param_groups(self) -> dict[str, list[nn.Parameter]]:
    """
    Trainable parameters bucketed by function.

    Keys name the functional role of each bucket (embedding, norms,
    per-layer SNN-block / FFN weights, dynamic-K halting projections, ...)
    so the trainer can assign per-group hyperparameters.
    """
    groups: dict[str, list[nn.Parameter]] = {
        'embedding': [self.embed_tokens.weight],
        'norm': [self.norm.gain],
        'decode': list(self.decode_proj.parameters()),
        # output neuron
        'output_neuron': [self.output_neuron.w, self.output_neuron.v_th],
        # RMSNorm (Pre-LN branch normalization)
        'rms_norms': [self.output_norm.weight],
        # residual-stream components
        'residual_projs': [],
        'input_neurons': [],
        # dynamic K: halting projections
        'halt_projs': [],
        # SNNBlock parameters
        'W_in': [],
        'W_beta': [],
        'W_alpha': [],
        'W_th': [],
        'W_gate': [],
        'W_skip': [],
        'W_out': [],
        'b_beta': [],
        'b_alpha': [],
        'b_th': [],
        'block_output_neuron': [],
        # SNNFFN parameters
        'ffn_gate_proj': [],
        'ffn_up_proj': [],
        'ffn_down_proj': [],
        'ffn_skip_proj': [],
        'ffn_neurons': [],
    }

    for lyr in self.layers:
        blk = lyr.snn_block
        ffn = lyr.snn_ffn

        # residual-stream components
        groups['residual_projs'] += [
            lyr.block_out_proj.weight,
            lyr.ffn_out_proj.weight,
        ]
        groups['input_neurons'] += [
            lyr.input_neuron1.w,
            lyr.input_neuron1.v_th,
            lyr.input_neuron2.w,
            lyr.input_neuron2.v_th,
        ]
        groups['rms_norms'] += [
            lyr.block_norm.weight,
            lyr.ffn_norm.weight,
        ]

        # dynamic K: halting-projection parameters
        groups['halt_projs'] += list(lyr.block_halt.parameters())
        groups['halt_projs'] += list(lyr.ffn_halt.parameters())

        # SNNBlock parameters
        groups['W_in'].append(blk.W_in.weight)
        groups['W_beta'].append(blk.W_beta_x.weight)
        groups['W_alpha'].append(blk.W_alpha_x.weight)
        groups['W_th'].append(blk.W_th_x.weight)
        groups['W_gate'].append(blk.W_gate.weight)
        groups['W_skip'].append(blk.W_skip.weight)
        groups['W_out'].append(blk.W_out.weight)
        groups['b_beta'].append(blk.b_beta)
        groups['b_alpha'].append(blk.b_alpha)
        groups['b_th'].append(blk.b_th)

        # SNNFFN parameters
        groups['ffn_gate_proj'].append(ffn.gate_proj.weight)
        groups['ffn_up_proj'].append(ffn.up_proj.weight)
        groups['ffn_down_proj'].append(ffn.down_proj.weight)
        groups['ffn_skip_proj'].append(ffn.skip_proj.weight)
        groups['ffn_neurons'] += [
            ffn.gate_neuron.w, ffn.gate_neuron.v_th,
            ffn.up_neuron.w, ffn.up_neuron.v_th,
        ]

    return groups
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:85a3b5312303134fced86a9d352175b91682176b6dea6af176f64bdaf6fc4b57
|
| 3 |
+
size 3496634368
|
modeling_neuronspark.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NeuronSpark: SNN 隐状态空间语言模型 — HuggingFace 接口
|
| 3 |
+
|
| 4 |
+
用法:
|
| 5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 6 |
+
|
| 7 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 8 |
+
"checkpoints_sft/", trust_remote_code=True,
|
| 9 |
+
)
|
| 10 |
+
tokenizer = AutoTokenizer.from_pretrained("checkpoints_sft/")
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import PreTrainedModel, GenerationMixin
|
| 17 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 18 |
+
|
| 19 |
+
from configuration_neuronspark import NeuronSparkConfig
|
| 20 |
+
from model import SNNLanguageModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class NeuronSparkForCausalLM(PreTrainedModel, GenerationMixin):
    """
    SNN language model — CausalLM interface.

    Wraps SNNLanguageModel with the standard HuggingFace surface:
    - forward(input_ids, labels) -> CausalLMOutputWithPast
    - generate() support (delegated to the SNN model's own loop)
    """
    config_class = NeuronSparkConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: NeuronSparkConfig):
        super().__init__(config)
        self.model = SNNLanguageModel(
            vocab_size=config.vocab_size,
            D=config.D,
            N=config.N,
            K=config.K,
            num_layers=config.num_layers,
            D_ff=config.D_ff,
            v_th_min=config.v_th_min,
        )

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        # Tied head: the output projection reuses embed_tokens.weight.
        return self.model.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass.

        Args:
            input_ids: (batch, seq_len) token IDs.
            labels: (batch, seq_len) target token IDs (optional; enables loss).
            attention_mask: accepted for API compatibility; the SNN has no
                attention mechanism, so it is ignored.
        """
        if labels is not None:
            out = self.model(input_ids, target_ids=labels)
            # Masked mean over non-padding positions (label 0 = padding).
            loss_mask = (labels != 0).float().view(-1)
            # clamp(min=1) guards against an all-padding batch (mask sum 0),
            # which would otherwise produce NaN from 0/0.
            loss = (out.last_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)
            # Add the ponder-cost regularizer (dynamic-K step penalty).
            if out.ponder_cost is not None:
                loss = loss + 0.01 * out.ponder_cost
            return CausalLMOutputWithPast(loss=loss)
        else:
            out = self.model(input_ids)
            return CausalLMOutputWithPast(logits=out.logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Input-preparation hook required by generate()."""
        return {"input_ids": input_ids}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 256,
        temperature: float = 1.0,
        top_k: int = 50,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Autoregressive generation (delegates to the SNN model's generate()).
        """
        return self.model.generate(
            prompt_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            eos_token_id=eos_token_id,
        )
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<|im_start|>",
|
| 3 |
+
"eos_token": "<|im_end|>",
|
| 4 |
+
"unk_token": "<unk>",
|
| 5 |
+
"pad_token": "<|im_end|>",
|
| 6 |
+
"additional_special_tokens": [
|
| 7 |
+
"<s>",
|
| 8 |
+
"</s>"
|
| 9 |
+
]
|
| 10 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": false,
|
| 5 |
+
"bos_token": "<|im_start|>",
|
| 6 |
+
"eos_token": "<|im_end|>",
|
| 7 |
+
"pad_token": "<|im_end|>",
|
| 8 |
+
"unk_token": "<unk>",
|
| 9 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 10 |
+
"clean_up_tokenization_spaces": false,
|
| 11 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 12 |
+
"chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
| 13 |
+
}
|