seconds-0 committed on
Commit 4303959 · verified · 1 Parent(s): 117a62c

NSA 117M initial export

LICENSE ADDED
@@ -0,0 +1 @@
1
+ Apache-2.0
README.md ADDED
@@ -0,0 +1,61 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - nsa
7
+ - sparse-attention
8
+ - 117m
9
+ datasets:
10
+ - fineweb-edu
11
+ library_name: transformers
12
+ pipeline_tag: text-generation
13
+ base_model: byte-256
14
+ ---
15
+
16
+ # NSA 117M (FineWeb-Edu) — Remote Code
17
+
18
+ This repository contains a 117M-parameter NSA decoder-only model packaged as Transformers remote code. It exposes `NSAConfig` and `NSAForCausalLM`, so you can load it via:
19
+
20
+ ```python
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+ m = AutoModelForCausalLM.from_pretrained("seconds-0/nsa-117m-byte", trust_remote_code=True)
23
+ t = AutoTokenizer.from_pretrained("seconds-0/nsa-117m-byte", trust_remote_code=True)
24
+ out = m.generate(**t("Hello", return_tensors="pt"), max_new_tokens=16)
25
+ ```
26
+
27
+ ## What is NSA
28
+
29
+ Native Sparse Attention (NSA) combines three branches — compressed (cmp), selected (sel), and sliding window (win) — mixed by a learned gate. The 117M configuration uses SDPA everywhere and keeps strict causality.
30
+
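+ Concretely, the mix is a per-token convex combination of the three branch outputs. A minimal sketch of the gating (shapes follow the embedded `modeling_nsa.py`; this is an illustration, not the exact module):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def mix_branches(gate_logits: torch.Tensor,
+                  O_cmp: torch.Tensor, O_sel: torch.Tensor, O_win: torch.Tensor) -> torch.Tensor:
+     # gate_logits: [B, S, 3] from a tiny MLP whose last layer is zero-initialized,
+     # so softmax (τ = 1.0) starts at a uniform 1/3 mix of the three branches.
+     g = F.softmax(gate_logits, dim=-1)                  # [B, S, 3]
+     gc, gs, gw = g[..., 0:1], g[..., 1:2], g[..., 2:3]  # each [B, S, 1]
+     # Branch outputs are [B, h, S, d_v]; broadcast the per-token gates over heads and d_v.
+     return gc.unsqueeze(1) * O_cmp + gs.unsqueeze(1) * O_sel + gw.unsqueeze(1) * O_win
+ ```
+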
31
+ Architecture (overview):
32
+ - cmp: compressed blocks (tile length l, stride d) attended with causal masks
33
+ - sel: top-n selection over blockized keys (block l′, n ranges per step)
34
+ - win: sliding window attention of size w
35
+ - gate: small MLP (zero-initialized last layer), softmax(τ=1.0)
36
+
37
+ Defaults: l=32, d=16, l′=64, n=16, w=512; GQA groups=2.
38
+
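+ These defaults map directly onto the `nsa` dict consumed by `NSAConfig` (field names as in this repo's `config.json`). A sketch, assuming the repo files are importable locally; when loading from the Hub you would normally go through `AutoConfig` with `trust_remote_code=True`:
+
+ ```python
+ from configuration_nsa import NSAConfig
+
+ cfg = NSAConfig(
+     vocab_size=256,        # byte-level vocab
+     n_kv_groups=2,         # GQA groups
+     nsa={
+         "branches": ["cmp", "sel", "win"],
+         "block": 32,       # l  (compressed tile length)
+         "stride": 16,      # d
+         "sel_block": 64,   # l′
+         "sel_top_n": 16,   # n
+         "window": 512,     # w
+         "gqa_groups": 2,
+     },
+ )
+ ```
+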
39
+ ## Performance & Metrics (example targets)
40
+
41
+ - A100 40GB: ≥600 tok/s; TTFT ≤ 350 ms (batch=1, seq=128)
42
+ - RTX 4090: ≥400 tok/s; TTFT ≤ 450 ms
43
+ - CPU: ≥10 tok/s; TTFT ≤ 2.0 s
44
+
45
+ ## Intended Use / Limitations
46
+
47
+ - Toy assistant and demos; not suitable for high-stakes use.
48
+
49
+ ## Memory Budget (KV Cache)
50
+
51
+ - Standard LM approx: Mem ≈ t × H × (d_k + d_v) × bytes_per_elem
52
+ - NSA decode (M0): Mem ≈ (min(w, t) + n × l′) × H × (d_k + d_v) × bytes_per_elem
53
+ - Example (w=512, n=16, l′=64): tokens_cached ≈ min(512, t) + 1024 (FP16 → a few MiB for 117M dims)
54
+
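+ As a worked example of the M0 decode formula above (a sketch with assumptions: FP16, H read as the 2 GQA KV groups, d_k = d_v = 64, and 12 layers as in `config.json`):
+
+ ```python
+ def nsa_kv_bytes(t, w=512, n=16, l_sel=64, layers=12, kv_groups=2,
+                  d_k=64, d_v=64, bytes_per_elem=2):
+     # Tokens held per layer/group under the simplified M0 formula.
+     tokens_cached = min(w, t) + n * l_sel          # ≈ 1536 once t ≥ w
+     per_layer = tokens_cached * kv_groups * (d_k + d_v) * bytes_per_elem
+     return layers * per_layer
+
+ print(nsa_kv_bytes(4096) / 2**20)                  # ≈ 9 MiB across 12 layers
+ ```
+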
55
+ ## Notes
56
+
57
+ - Tokenizer: byte-level tokenizer (vocab=256). This is not GPT‑2/BPE; input and output are raw UTF‑8 bytes (see the sketch after these notes).
58
+ - Generation cache: no KV cache in v1 (slower decode for long sequences). Planned follow‑up.
59
+ - Gate: the remote-code gate MLP (`g1`/`g2`) is zero‑initialized at the last layer, giving uniform mixing by design; its layout differs from the training checkpoint's gate (`gate.fc1`/`gate.fc2`), so trained gate weights are not loaded (see `logs/logs_mapping.json`).
60
+ - Remote code uses SDPA-only paths and includes a safe fallback block (`SimpleBlock`) if NSA is forcibly disabled via the `NSA_REMOTE_FORCE_SIMPLE` environment variable.
61
+
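+ Continuing the loading snippet above, a small sketch of byte-level round-tripping (assuming the tokenizer exposes the usual `decode` API and adds no extra special tokens):
+
+ ```python
+ ids = t("héllo", return_tensors="pt").input_ids
+ print(ids.shape[-1])   # counts UTF-8 bytes, not words: 'é' encodes to 2 bytes, so 6 here
+
+ out = m.generate(ids, max_new_tokens=16)
+ print(t.decode(out[0], skip_special_tokens=True))
+ ```
+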
config.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "model_type": "nsa",
3
+ "architectures": [
4
+ "NSAForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_nsa.NSAConfig",
8
+ "AutoModelForCausalLM": "modeling_nsa.NSAForCausalLM",
9
+ "AutoTokenizer": [
10
+ "tokenization_nsa.NSAByteTokenizer",
11
+ null
12
+ ]
13
+ },
14
+ "vocab_size": 256,
15
+ "hidden_size": 768,
16
+ "num_hidden_layers": 12,
17
+ "num_attention_heads": 12,
18
+ "n_kv_groups": 2,
19
+ "d_k": 64,
20
+ "d_v": 64,
21
+ "max_position_embeddings": 2048,
22
+ "rope_theta": 10000,
23
+ "nsa": {
24
+ "branches": [
25
+ "cmp",
26
+ "sel",
27
+ "win"
28
+ ],
29
+ "window": 512,
30
+ "gqa_groups": 2,
31
+ "block": 32,
32
+ "stride": 16,
33
+ "sel_block": 64,
34
+ "sel_top_n": 16
35
+ }
36
+ }
configuration_nsa.py ADDED
@@ -0,0 +1,40 @@
1
+ # Remote code: configuration and modeling for NSA
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class NSAConfig(PretrainedConfig):
6
+ model_type = "nsa"
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=50257,
11
+ hidden_size=768,
12
+ num_hidden_layers=12,
13
+ num_attention_heads=12,
14
+ n_kv_groups=1,
15
+ d_k=64,
16
+ d_v=64,
17
+ max_position_embeddings=2048,
18
+ rope_theta=10000,
19
+ nsa=None,
20
+ **kwargs,
21
+ ):
22
+ super().__init__(**kwargs)
23
+ self.vocab_size = vocab_size
24
+ self.hidden_size = hidden_size
25
+ self.num_hidden_layers = num_hidden_layers
26
+ self.num_attention_heads = num_attention_heads
27
+ self.n_kv_groups = n_kv_groups
28
+ self.d_k = d_k
29
+ self.d_v = d_v
30
+ self.max_position_embeddings = max_position_embeddings
31
+ self.rope_theta = rope_theta
32
+ self.nsa = nsa or {
33
+ "branches": ["cmp", "sel", "win"],
34
+ "window": 512,
35
+ "gqa_groups": n_kv_groups,
36
+ "block": 32,
37
+ "stride": 16,
38
+ "sel_block": 64,
39
+ "sel_top_n": 16,
40
+ }
logs/logs_extra_keys.txt ADDED
@@ -0,0 +1,49 @@
1
+ blocks.0.attn.gate.fc1.bias
2
+ blocks.0.attn.gate.fc1.weight
3
+ blocks.0.attn.gate.fc2.bias
4
+ blocks.0.attn.gate.fc2.weight
5
+ blocks.1.attn.gate.fc1.bias
6
+ blocks.1.attn.gate.fc1.weight
7
+ blocks.1.attn.gate.fc2.bias
8
+ blocks.1.attn.gate.fc2.weight
9
+ blocks.10.attn.gate.fc1.bias
10
+ blocks.10.attn.gate.fc1.weight
11
+ blocks.10.attn.gate.fc2.bias
12
+ blocks.10.attn.gate.fc2.weight
13
+ blocks.11.attn.gate.fc1.bias
14
+ blocks.11.attn.gate.fc1.weight
15
+ blocks.11.attn.gate.fc2.bias
16
+ blocks.11.attn.gate.fc2.weight
17
+ blocks.2.attn.gate.fc1.bias
18
+ blocks.2.attn.gate.fc1.weight
19
+ blocks.2.attn.gate.fc2.bias
20
+ blocks.2.attn.gate.fc2.weight
21
+ blocks.3.attn.gate.fc1.bias
22
+ blocks.3.attn.gate.fc1.weight
23
+ blocks.3.attn.gate.fc2.bias
24
+ blocks.3.attn.gate.fc2.weight
25
+ blocks.4.attn.gate.fc1.bias
26
+ blocks.4.attn.gate.fc1.weight
27
+ blocks.4.attn.gate.fc2.bias
28
+ blocks.4.attn.gate.fc2.weight
29
+ blocks.5.attn.gate.fc1.bias
30
+ blocks.5.attn.gate.fc1.weight
31
+ blocks.5.attn.gate.fc2.bias
32
+ blocks.5.attn.gate.fc2.weight
33
+ blocks.6.attn.gate.fc1.bias
34
+ blocks.6.attn.gate.fc1.weight
35
+ blocks.6.attn.gate.fc2.bias
36
+ blocks.6.attn.gate.fc2.weight
37
+ blocks.7.attn.gate.fc1.bias
38
+ blocks.7.attn.gate.fc1.weight
39
+ blocks.7.attn.gate.fc2.bias
40
+ blocks.7.attn.gate.fc2.weight
41
+ blocks.8.attn.gate.fc1.bias
42
+ blocks.8.attn.gate.fc1.weight
43
+ blocks.8.attn.gate.fc2.bias
44
+ blocks.8.attn.gate.fc2.weight
45
+ blocks.9.attn.gate.fc1.bias
46
+ blocks.9.attn.gate.fc1.weight
47
+ blocks.9.attn.gate.fc2.bias
48
+ blocks.9.attn.gate.fc2.weight
49
+ norm_f.weight
logs/logs_mapping.json ADDED
@@ -0,0 +1,229 @@
1
+ {
2
+ "mapped": [
3
+ "model.blocks.0.attn.W_K_cmp.weight",
4
+ "model.blocks.0.attn.W_K_sel.weight",
5
+ "model.blocks.0.attn.W_K_win.weight",
6
+ "model.blocks.0.attn.W_Q.weight",
7
+ "model.blocks.0.attn.W_V_cmp.weight",
8
+ "model.blocks.0.attn.W_V_sel.weight",
9
+ "model.blocks.0.attn.W_V_win.weight",
10
+ "model.blocks.0.attn.out.weight",
11
+ "model.blocks.0.mlp.fc1.weight",
12
+ "model.blocks.0.mlp.fc2.weight",
13
+ "model.blocks.0.norm1.weight",
14
+ "model.blocks.0.norm2.weight",
15
+ "model.blocks.1.attn.W_K_cmp.weight",
16
+ "model.blocks.1.attn.W_K_sel.weight",
17
+ "model.blocks.1.attn.W_K_win.weight",
18
+ "model.blocks.1.attn.W_Q.weight",
19
+ "model.blocks.1.attn.W_V_cmp.weight",
20
+ "model.blocks.1.attn.W_V_sel.weight",
21
+ "model.blocks.1.attn.W_V_win.weight",
22
+ "model.blocks.1.attn.out.weight",
23
+ "model.blocks.1.mlp.fc1.weight",
24
+ "model.blocks.1.mlp.fc2.weight",
25
+ "model.blocks.1.norm1.weight",
26
+ "model.blocks.1.norm2.weight",
27
+ "model.blocks.10.attn.W_K_cmp.weight",
28
+ "model.blocks.10.attn.W_K_sel.weight",
29
+ "model.blocks.10.attn.W_K_win.weight",
30
+ "model.blocks.10.attn.W_Q.weight",
31
+ "model.blocks.10.attn.W_V_cmp.weight",
32
+ "model.blocks.10.attn.W_V_sel.weight",
33
+ "model.blocks.10.attn.W_V_win.weight",
34
+ "model.blocks.10.attn.out.weight",
35
+ "model.blocks.10.mlp.fc1.weight",
36
+ "model.blocks.10.mlp.fc2.weight",
37
+ "model.blocks.10.norm1.weight",
38
+ "model.blocks.10.norm2.weight",
39
+ "model.blocks.11.attn.W_K_cmp.weight",
40
+ "model.blocks.11.attn.W_K_sel.weight",
41
+ "model.blocks.11.attn.W_K_win.weight",
42
+ "model.blocks.11.attn.W_Q.weight",
43
+ "model.blocks.11.attn.W_V_cmp.weight",
44
+ "model.blocks.11.attn.W_V_sel.weight",
45
+ "model.blocks.11.attn.W_V_win.weight",
46
+ "model.blocks.11.attn.out.weight",
47
+ "model.blocks.11.mlp.fc1.weight",
48
+ "model.blocks.11.mlp.fc2.weight",
49
+ "model.blocks.11.norm1.weight",
50
+ "model.blocks.11.norm2.weight",
51
+ "model.blocks.2.attn.W_K_cmp.weight",
52
+ "model.blocks.2.attn.W_K_sel.weight",
53
+ "model.blocks.2.attn.W_K_win.weight",
54
+ "model.blocks.2.attn.W_Q.weight",
55
+ "model.blocks.2.attn.W_V_cmp.weight",
56
+ "model.blocks.2.attn.W_V_sel.weight",
57
+ "model.blocks.2.attn.W_V_win.weight",
58
+ "model.blocks.2.attn.out.weight",
59
+ "model.blocks.2.mlp.fc1.weight",
60
+ "model.blocks.2.mlp.fc2.weight",
61
+ "model.blocks.2.norm1.weight",
62
+ "model.blocks.2.norm2.weight",
63
+ "model.blocks.3.attn.W_K_cmp.weight",
64
+ "model.blocks.3.attn.W_K_sel.weight",
65
+ "model.blocks.3.attn.W_K_win.weight",
66
+ "model.blocks.3.attn.W_Q.weight",
67
+ "model.blocks.3.attn.W_V_cmp.weight",
68
+ "model.blocks.3.attn.W_V_sel.weight",
69
+ "model.blocks.3.attn.W_V_win.weight",
70
+ "model.blocks.3.attn.out.weight",
71
+ "model.blocks.3.mlp.fc1.weight",
72
+ "model.blocks.3.mlp.fc2.weight",
73
+ "model.blocks.3.norm1.weight",
74
+ "model.blocks.3.norm2.weight",
75
+ "model.blocks.4.attn.W_K_cmp.weight",
76
+ "model.blocks.4.attn.W_K_sel.weight",
77
+ "model.blocks.4.attn.W_K_win.weight",
78
+ "model.blocks.4.attn.W_Q.weight",
79
+ "model.blocks.4.attn.W_V_cmp.weight",
80
+ "model.blocks.4.attn.W_V_sel.weight",
81
+ "model.blocks.4.attn.W_V_win.weight",
82
+ "model.blocks.4.attn.out.weight",
83
+ "model.blocks.4.mlp.fc1.weight",
84
+ "model.blocks.4.mlp.fc2.weight",
85
+ "model.blocks.4.norm1.weight",
86
+ "model.blocks.4.norm2.weight",
87
+ "model.blocks.5.attn.W_K_cmp.weight",
88
+ "model.blocks.5.attn.W_K_sel.weight",
89
+ "model.blocks.5.attn.W_K_win.weight",
90
+ "model.blocks.5.attn.W_Q.weight",
91
+ "model.blocks.5.attn.W_V_cmp.weight",
92
+ "model.blocks.5.attn.W_V_sel.weight",
93
+ "model.blocks.5.attn.W_V_win.weight",
94
+ "model.blocks.5.attn.out.weight",
95
+ "model.blocks.5.mlp.fc1.weight",
96
+ "model.blocks.5.mlp.fc2.weight",
97
+ "model.blocks.5.norm1.weight",
98
+ "model.blocks.5.norm2.weight",
99
+ "model.blocks.6.attn.W_K_cmp.weight",
100
+ "model.blocks.6.attn.W_K_sel.weight",
101
+ "model.blocks.6.attn.W_K_win.weight",
102
+ "model.blocks.6.attn.W_Q.weight",
103
+ "model.blocks.6.attn.W_V_cmp.weight",
104
+ "model.blocks.6.attn.W_V_sel.weight",
105
+ "model.blocks.6.attn.W_V_win.weight",
106
+ "model.blocks.6.attn.out.weight",
107
+ "model.blocks.6.mlp.fc1.weight",
108
+ "model.blocks.6.mlp.fc2.weight",
109
+ "model.blocks.6.norm1.weight",
110
+ "model.blocks.6.norm2.weight",
111
+ "model.blocks.7.attn.W_K_cmp.weight",
112
+ "model.blocks.7.attn.W_K_sel.weight",
113
+ "model.blocks.7.attn.W_K_win.weight",
114
+ "model.blocks.7.attn.W_Q.weight",
115
+ "model.blocks.7.attn.W_V_cmp.weight",
116
+ "model.blocks.7.attn.W_V_sel.weight",
117
+ "model.blocks.7.attn.W_V_win.weight",
118
+ "model.blocks.7.attn.out.weight",
119
+ "model.blocks.7.mlp.fc1.weight",
120
+ "model.blocks.7.mlp.fc2.weight",
121
+ "model.blocks.7.norm1.weight",
122
+ "model.blocks.7.norm2.weight",
123
+ "model.blocks.8.attn.W_K_cmp.weight",
124
+ "model.blocks.8.attn.W_K_sel.weight",
125
+ "model.blocks.8.attn.W_K_win.weight",
126
+ "model.blocks.8.attn.W_Q.weight",
127
+ "model.blocks.8.attn.W_V_cmp.weight",
128
+ "model.blocks.8.attn.W_V_sel.weight",
129
+ "model.blocks.8.attn.W_V_win.weight",
130
+ "model.blocks.8.attn.out.weight",
131
+ "model.blocks.8.mlp.fc1.weight",
132
+ "model.blocks.8.mlp.fc2.weight",
133
+ "model.blocks.8.norm1.weight",
134
+ "model.blocks.8.norm2.weight",
135
+ "model.blocks.9.attn.W_K_cmp.weight",
136
+ "model.blocks.9.attn.W_K_sel.weight",
137
+ "model.blocks.9.attn.W_K_win.weight",
138
+ "model.blocks.9.attn.W_Q.weight",
139
+ "model.blocks.9.attn.W_V_cmp.weight",
140
+ "model.blocks.9.attn.W_V_sel.weight",
141
+ "model.blocks.9.attn.W_V_win.weight",
142
+ "model.blocks.9.attn.out.weight",
143
+ "model.blocks.9.mlp.fc1.weight",
144
+ "model.blocks.9.mlp.fc2.weight",
145
+ "model.blocks.9.norm1.weight",
146
+ "model.blocks.9.norm2.weight",
147
+ "model.embed.weight",
148
+ "model.lm_head.weight"
149
+ ],
150
+ "missing": [
151
+ "model.blocks.0.attn.g1.weight",
152
+ "model.blocks.0.attn.g2.weight",
153
+ "model.blocks.1.attn.g1.weight",
154
+ "model.blocks.1.attn.g2.weight",
155
+ "model.blocks.10.attn.g1.weight",
156
+ "model.blocks.10.attn.g2.weight",
157
+ "model.blocks.11.attn.g1.weight",
158
+ "model.blocks.11.attn.g2.weight",
159
+ "model.blocks.2.attn.g1.weight",
160
+ "model.blocks.2.attn.g2.weight",
161
+ "model.blocks.3.attn.g1.weight",
162
+ "model.blocks.3.attn.g2.weight",
163
+ "model.blocks.4.attn.g1.weight",
164
+ "model.blocks.4.attn.g2.weight",
165
+ "model.blocks.5.attn.g1.weight",
166
+ "model.blocks.5.attn.g2.weight",
167
+ "model.blocks.6.attn.g1.weight",
168
+ "model.blocks.6.attn.g2.weight",
169
+ "model.blocks.7.attn.g1.weight",
170
+ "model.blocks.7.attn.g2.weight",
171
+ "model.blocks.8.attn.g1.weight",
172
+ "model.blocks.8.attn.g2.weight",
173
+ "model.blocks.9.attn.g1.weight",
174
+ "model.blocks.9.attn.g2.weight",
175
+ "model.norm.bias",
176
+ "model.norm.weight"
177
+ ],
178
+ "extra": [
179
+ "blocks.0.attn.gate.fc1.bias",
180
+ "blocks.0.attn.gate.fc1.weight",
181
+ "blocks.0.attn.gate.fc2.bias",
182
+ "blocks.0.attn.gate.fc2.weight",
183
+ "blocks.1.attn.gate.fc1.bias",
184
+ "blocks.1.attn.gate.fc1.weight",
185
+ "blocks.1.attn.gate.fc2.bias",
186
+ "blocks.1.attn.gate.fc2.weight",
187
+ "blocks.10.attn.gate.fc1.bias",
188
+ "blocks.10.attn.gate.fc1.weight",
189
+ "blocks.10.attn.gate.fc2.bias",
190
+ "blocks.10.attn.gate.fc2.weight",
191
+ "blocks.11.attn.gate.fc1.bias",
192
+ "blocks.11.attn.gate.fc1.weight",
193
+ "blocks.11.attn.gate.fc2.bias",
194
+ "blocks.11.attn.gate.fc2.weight",
195
+ "blocks.2.attn.gate.fc1.bias",
196
+ "blocks.2.attn.gate.fc1.weight",
197
+ "blocks.2.attn.gate.fc2.bias",
198
+ "blocks.2.attn.gate.fc2.weight",
199
+ "blocks.3.attn.gate.fc1.bias",
200
+ "blocks.3.attn.gate.fc1.weight",
201
+ "blocks.3.attn.gate.fc2.bias",
202
+ "blocks.3.attn.gate.fc2.weight",
203
+ "blocks.4.attn.gate.fc1.bias",
204
+ "blocks.4.attn.gate.fc1.weight",
205
+ "blocks.4.attn.gate.fc2.bias",
206
+ "blocks.4.attn.gate.fc2.weight",
207
+ "blocks.5.attn.gate.fc1.bias",
208
+ "blocks.5.attn.gate.fc1.weight",
209
+ "blocks.5.attn.gate.fc2.bias",
210
+ "blocks.5.attn.gate.fc2.weight",
211
+ "blocks.6.attn.gate.fc1.bias",
212
+ "blocks.6.attn.gate.fc1.weight",
213
+ "blocks.6.attn.gate.fc2.bias",
214
+ "blocks.6.attn.gate.fc2.weight",
215
+ "blocks.7.attn.gate.fc1.bias",
216
+ "blocks.7.attn.gate.fc1.weight",
217
+ "blocks.7.attn.gate.fc2.bias",
218
+ "blocks.7.attn.gate.fc2.weight",
219
+ "blocks.8.attn.gate.fc1.bias",
220
+ "blocks.8.attn.gate.fc1.weight",
221
+ "blocks.8.attn.gate.fc2.bias",
222
+ "blocks.8.attn.gate.fc2.weight",
223
+ "blocks.9.attn.gate.fc1.bias",
224
+ "blocks.9.attn.gate.fc1.weight",
225
+ "blocks.9.attn.gate.fc2.bias",
226
+ "blocks.9.attn.gate.fc2.weight",
227
+ "norm_f.weight"
228
+ ]
229
+ }
logs/logs_missing_keys.txt ADDED
@@ -0,0 +1,26 @@
1
+ model.blocks.0.attn.g1.weight
2
+ model.blocks.0.attn.g2.weight
3
+ model.blocks.1.attn.g1.weight
4
+ model.blocks.1.attn.g2.weight
5
+ model.blocks.10.attn.g1.weight
6
+ model.blocks.10.attn.g2.weight
7
+ model.blocks.11.attn.g1.weight
8
+ model.blocks.11.attn.g2.weight
9
+ model.blocks.2.attn.g1.weight
10
+ model.blocks.2.attn.g2.weight
11
+ model.blocks.3.attn.g1.weight
12
+ model.blocks.3.attn.g2.weight
13
+ model.blocks.4.attn.g1.weight
14
+ model.blocks.4.attn.g2.weight
15
+ model.blocks.5.attn.g1.weight
16
+ model.blocks.5.attn.g2.weight
17
+ model.blocks.6.attn.g1.weight
18
+ model.blocks.6.attn.g2.weight
19
+ model.blocks.7.attn.g1.weight
20
+ model.blocks.7.attn.g2.weight
21
+ model.blocks.8.attn.g1.weight
22
+ model.blocks.8.attn.g2.weight
23
+ model.blocks.9.attn.g1.weight
24
+ model.blocks.9.attn.g2.weight
25
+ model.norm.bias
26
+ model.norm.weight
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92e303af798306020bcf0b1a6293a9e88027887b70d61d110fc4cba274cedf66
3
+ size 320203152
modeling_nsa.py ADDED
@@ -0,0 +1,311 @@
1
+ # Remote code: configuration and modeling for NSA
2
+ import math
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from torch import nn
7
+ from transformers import PreTrainedModel
8
+ from transformers.generation.utils import GenerationMixin
9
+ from transformers.modeling_outputs import CausalLMOutput
10
+
11
+ from .configuration_nsa import NSAConfig
12
+ _HAS_NSA = False # Embedded NSA is provided below; no external import required.
13
+
14
+
15
+ class RMSNorm(nn.Module):
16
+ def __init__(self, dim: int, eps: float = 1e-6) -> None:
17
+ super().__init__()
18
+ self.weight = nn.Parameter(torch.ones(dim))
19
+ self.eps = eps
20
+
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
23
+ return (x * rms) * self.weight
24
+
25
+
26
+ class MLP(nn.Module):
27
+ def __init__(self, dim: int, hidden_mult: int = 4) -> None:
28
+ super().__init__()
29
+ h = hidden_mult * dim
30
+ self.fc1 = nn.Linear(dim, h, bias=False)
31
+ self.fc2 = nn.Linear(h, dim, bias=False)
32
+
33
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
34
+ return self.fc2(torch.nn.functional.silu(self.fc1(x)))
35
+
36
+
37
+ def _rope(q: torch.Tensor) -> torch.Tensor:
38
+ B, S, D = q.shape[0], q.shape[2], q.shape[-1]
39
+ if D % 2 != 0:
40
+ return q
41
+ device = q.device
42
+ half = D // 2
43
+ pos = torch.arange(S, device=device).float().unsqueeze(-1)
44
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, half, device=device).float() / half))
45
+ angles = pos * inv_freq
46
+ cos = angles.cos().view(1, 1, S, half)
47
+ sin = angles.sin().view(1, 1, S, half)
48
+ q1, q2 = q[..., :half], q[..., half:]
49
+ return torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1)
50
+
51
+
52
+ def _avg_pool_time(x: torch.Tensor, kernel: int, stride: int) -> torch.Tensor:
53
+ if x.shape[2] < kernel:
54
+ return x[..., :0, :]
55
+ xt = x.permute(0, 3, 1, 2).contiguous()
56
+ y = torch.nn.functional.avg_pool2d(xt, kernel_size=(1, kernel), stride=(1, stride))
57
+ return y.permute(0, 2, 3, 1).contiguous()
58
+
59
+
60
+ def _window_mask(q: torch.Tensor, S: int, w: int) -> torch.Tensor:
61
+ B, h = q.shape[0], q.shape[1]
62
+ device = q.device
63
+ row = torch.arange(S, device=device).view(S, 1)
64
+ col = torch.arange(S, device=device).view(1, S)
65
+ allowed = (col <= row) & (col >= (row - (w - 1)))
66
+ M = torch.full((S, S), float('-inf'), device=device, dtype=q.dtype)
67
+ M.masked_fill_(allowed, 0.0)
68
+ return M.view(1, 1, S, S).expand(B, h, S, S)
69
+
70
+
71
+ def _selection_blocks(scores: torch.Tensor, l_sel: int, n_sel: int) -> torch.Tensor:
72
+ B, h, S = scores.shape
73
+ n_blocks = max(1, (S + l_sel - 1) // l_sel)
74
+ # Pad to multiple of l_sel
75
+ pad = n_blocks * l_sel - S
76
+ if pad > 0:
77
+ scores = torch.nn.functional.pad(scores, (0, pad), value=-1e9)
78
+ blk_scores = scores.view(B, h, n_blocks, l_sel).max(dim=-1).values
79
+ k = min(n_sel, n_blocks)
80
+ return torch.topk(blk_scores, k=k, dim=-1).indices
81
+
82
+
83
+ class EmbeddedNSAAttention(nn.Module):
84
+ def __init__(self, dim: int, n_heads: int, n_kv_groups: int, d_k: int, d_v: int,
85
+ l: int, d: int, l_sel: int, n_sel: int, w: int) -> None:
86
+ super().__init__()
87
+ self.n_heads = n_heads
88
+ self.n_kv_groups = n_kv_groups
89
+ self.d_k = d_k
90
+ self.d_v = d_v
91
+ self.l = l
92
+ self.stride = d
93
+ self.l_sel = l_sel
94
+ self.n_sel = n_sel
95
+ self.w = w
96
+ self.W_Q = nn.Linear(dim, n_heads * d_k, bias=False)
97
+ self.W_K_cmp = nn.Linear(dim, n_kv_groups * d_k, bias=False)
98
+ self.W_V_cmp = nn.Linear(dim, n_kv_groups * d_v, bias=False)
99
+ self.W_K_sel = nn.Linear(dim, n_kv_groups * d_k, bias=False)
100
+ self.W_V_sel = nn.Linear(dim, n_kv_groups * d_v, bias=False)
101
+ self.W_K_win = nn.Linear(dim, n_kv_groups * d_k, bias=False)
102
+ self.W_V_win = nn.Linear(dim, n_kv_groups * d_v, bias=False)
103
+ self.g1 = nn.Linear(dim, max(1, dim // 4), bias=False)
104
+ self.g2 = nn.Linear(max(1, dim // 4), 3, bias=False)
105
+ nn.init.zeros_(self.g2.weight)
106
+ self.out = nn.Linear(n_heads * d_v, dim, bias=False)
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ B, S, D = x.shape
110
+ h, dk, dv = self.n_heads, self.d_k, self.d_v
111
+ Q = self.W_Q(x).view(B, S, h, dk).transpose(1, 2) # [B,h,S,dk]
112
+ g = max(1, self.n_kv_groups)
113
+ r = max(1, h // g)
114
+ # Project per-group K/V then broadcast to heads
115
+ Kc_g = self.W_K_cmp(x).view(B, S, g, dk).permute(0, 2, 1, 3) # [B,g,S,dk]
116
+ Vc_g = self.W_V_cmp(x).view(B, S, g, dv).permute(0, 2, 1, 3)
117
+ Ks_g = self.W_K_sel(x).view(B, S, g, dk).permute(0, 2, 1, 3)
118
+ Vs_g = self.W_V_sel(x).view(B, S, g, dv).permute(0, 2, 1, 3)
119
+ Kw_g = self.W_K_win(x).view(B, S, g, dk).permute(0, 2, 1, 3)
120
+ Vw_g = self.W_V_win(x).view(B, S, g, dv).permute(0, 2, 1, 3)
121
+ # Broadcast groups to heads
122
+ def _bcast_to_heads(T):
123
+ return T.unsqueeze(1).expand(B, r, g, S, T.shape[-1]).reshape(B, h, S, T.shape[-1])
124
+ Kc = _bcast_to_heads(Kc_g)
125
+ Vc = _bcast_to_heads(Vc_g)
126
+ Ks = _bcast_to_heads(Ks_g)
127
+ Vs = _bcast_to_heads(Vs_g)
128
+ Kw = _bcast_to_heads(Kw_g)
129
+ Vw = _bcast_to_heads(Vw_g)
130
+
131
+ # RoPE
132
+ Qr = _rope(Q.transpose(1, 2)).transpose(1, 2)
133
+ Kc_r = _rope(Kc.transpose(1, 2)).transpose(1, 2)
134
+ Ks_r = _rope(Ks.transpose(1, 2)).transpose(1, 2)
135
+ Kw_r = _rope(Kw.transpose(1, 2)).transpose(1, 2)
136
+
137
+ # Compressed: average-pool along time
138
+ Kc_p = _avg_pool_time(Kc_r, kernel=max(1, self.stride), stride=max(1, self.stride))
139
+ Vc_p = _avg_pool_time(Vc, kernel=max(1, self.stride), stride=max(1, self.stride))
140
+ O_cmp = torch.nn.functional.scaled_dot_product_attention(Qr, Kc_p, Vc_p, is_causal=True)
141
+
142
+ # Selection: naive top-n blocks (global), enforce causal via triangular mask
143
+ scores = (Qr * Ks_r).mean(dim=-1) # [B,h,S]
144
+ blk_idx = _selection_blocks(scores, self.l_sel, self.n_sel) # [B,h,n]
145
+ n_blocks = max(1, (S + self.l_sel - 1) // self.l_sel)
146
+ keep = torch.zeros((B, h, n_blocks), device=x.device, dtype=torch.bool)
147
+ keep.scatter_(2, blk_idx, True)
148
+ keep = keep.unsqueeze(-1).expand(B, h, n_blocks, self.l_sel).reshape(B, h, -1)[:, :, :S]
149
+ logits = torch.matmul(Qr / math.sqrt(dk), Ks_r.transpose(-2, -1)) # [B,h,S,S]
150
+ tri = torch.triu(torch.ones((S, S), device=x.device, dtype=torch.bool), diagonal=1)
151
+ logits = logits.masked_fill(tri, float('-inf'))
152
+ sel_mask = torch.where(keep.unsqueeze(2).expand(B, h, S, S), torch.zeros((), device=x.device, dtype=Qr.dtype), torch.full((), float('-inf'), device=x.device, dtype=Qr.dtype))
153
+ P = torch.nn.functional.softmax(logits + sel_mask, dim=-1)
154
+ O_sel = torch.matmul(P, Vs)
155
+
156
+ # Sliding window
157
+ M = _window_mask(Qr, S, max(1, self.w))
158
+ logits_w = torch.matmul(Qr / math.sqrt(dk), Kw_r.transpose(-2, -1)) + M
159
+ P_w = torch.nn.functional.softmax(logits_w, dim=-1)
160
+ O_win = torch.matmul(P_w, Vw)
161
+
162
+ # Gate & mix
163
+ gate = self.g2(torch.nn.functional.silu(self.g1(x))) # [B,S,3]
164
+ gate = torch.nn.functional.softmax(gate, dim=-1)
165
+ gc, gs, gw = gate[..., 0:1], gate[..., 1:2], gate[..., 2:3]
166
+ O = gc.unsqueeze(1) * O_cmp + gs.unsqueeze(1) * O_sel + gw.unsqueeze(1) * O_win
167
+ O = O.transpose(1, 2).reshape(B, S, h * dv)
168
+ return self.out(O)
169
+
170
+ class SimpleAttention(nn.Module):
171
+ def __init__(self, dim: int, n_heads: int, d_k: int, d_v: int) -> None:
172
+ super().__init__()
173
+ self.n_heads = n_heads
174
+ self.d_k = d_k
175
+ self.d_v = d_v
176
+ self.q_proj = nn.Linear(dim, n_heads * d_k, bias=False)
177
+ self.k_proj = nn.Linear(dim, n_heads * d_k, bias=False)
178
+ self.v_proj = nn.Linear(dim, n_heads * d_v, bias=False)
179
+ self.out = nn.Linear(n_heads * d_v, dim, bias=False)
180
+
181
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
182
+ B, S, D = x.shape
183
+ h, dk, dv = self.n_heads, self.d_k, self.d_v
184
+ q = self.q_proj(x).view(B, S, h, dk).transpose(1, 2) # [B,h,S,dk]
185
+ k = self.k_proj(x).view(B, S, h, dk).transpose(1, 2) # [B,h,S,dk]
186
+ v = self.v_proj(x).view(B, S, h, dv).transpose(1, 2) # [B,h,S,dv]
187
+ attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
188
+ attn = attn.transpose(1, 2).contiguous().view(B, S, h * dv)
189
+ return self.out(attn)
190
+
191
+
192
+ class SimpleBlock(nn.Module):
193
+ def __init__(self, dim: int, n_heads: int, d_k: int, d_v: int) -> None:
194
+ super().__init__()
195
+ self.norm1 = RMSNorm(dim)
196
+ self.attn = SimpleAttention(dim, n_heads, d_k, d_v)
197
+ self.norm2 = RMSNorm(dim)
198
+ self.mlp = MLP(dim)
199
+
200
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
201
+ x = x + self.attn(self.norm1(x))
202
+ x = x + self.mlp(self.norm2(x))
203
+ return x
204
+
205
+
206
+ class NSABlockRemote(nn.Module):
207
+ """Transformer block with embedded NSA attention, pre/post RMSNorm, and MLP."""
208
+ def __init__(self, dim: int, n_heads: int, n_kv_groups: int, d_k: int, d_v: int,
209
+ l: int, d: int, l_sel: int, n_sel: int, w: int) -> None:
210
+ super().__init__()
211
+ self.norm1 = RMSNorm(dim)
212
+ self.attn = EmbeddedNSAAttention(dim, n_heads, n_kv_groups, d_k, d_v, l, d, l_sel, n_sel, w)
213
+ self.norm2 = RMSNorm(dim)
214
+ self.mlp = MLP(dim)
215
+
216
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
217
+ x = x + self.attn(self.norm1(x))
218
+ x = x + self.mlp(self.norm2(x))
219
+ return x
220
+
221
+ class NSATinyLM(nn.Module):
222
+ def __init__(self, config: NSAConfig):
223
+ super().__init__()
224
+ self.config = config
225
+ self.vocab_size = int(config.vocab_size)
226
+ self.hidden_size = int(config.hidden_size)
227
+ self.num_hidden_layers = int(config.num_hidden_layers)
228
+ self.num_attention_heads = int(config.num_attention_heads)
229
+ self.n_kv_groups = int(getattr(config, "n_kv_groups", 1))
230
+ self.d_k = int(getattr(config, "d_k", self.hidden_size // self.num_attention_heads))
231
+ self.d_v = int(getattr(config, "d_v", self.hidden_size // self.num_attention_heads))
232
+ nsa = config.nsa or {}
233
+ self.l = int(nsa.get("block", 32))
234
+ self.d = int(nsa.get("stride", 16))
235
+ self.l_sel = int(nsa.get("sel_block", 64))
236
+ self.n_sel = int(nsa.get("sel_top_n", 16))
237
+ self.w = int(nsa.get("window", 512))
238
+
239
+ self.embed = nn.Embedding(self.vocab_size, self.hidden_size)
240
+ import os as _os
241
+ # Allow forcing simple fallback via env for integration tests
242
+ _force_simple = _os.getenv('NSA_REMOTE_FORCE_SIMPLE', '0').lower() in ('1','true','yes')
243
+ if not _force_simple:
244
+ self.blocks = nn.ModuleList([
245
+ NSABlockRemote(
246
+ self.hidden_size,
247
+ self.num_attention_heads,
248
+ self.n_kv_groups,
249
+ self.d_k,
250
+ self.d_v,
251
+ self.l,
252
+ self.d,
253
+ self.l_sel,
254
+ self.n_sel,
255
+ self.w,
256
+ ) for _ in range(self.num_hidden_layers)
257
+ ])
258
+ else:
259
+ self.blocks = nn.ModuleList([
260
+ SimpleBlock(self.hidden_size, self.num_attention_heads, self.d_k, self.d_v)
261
+ for _ in range(self.num_hidden_layers)
262
+ ])
263
+ self.norm = nn.LayerNorm(self.hidden_size)
264
+ self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
265
+
266
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
267
+ x = self.embed(input_ids)
268
+ for blk in self.blocks:
269
+ x = blk(x)
270
+ x = self.norm(x)
271
+ logits = self.lm_head(x)
272
+ return logits
273
+
274
+
275
+ class NSAForCausalLM(PreTrainedModel, GenerationMixin):
276
+ config_class = NSAConfig
277
+ _no_split_modules = ["EmbeddedNSAAttention", "SimpleBlock"]
278
+
279
+ def __init__(self, config: NSAConfig):
280
+ super().__init__(config)
281
+ self.model = NSATinyLM(config)
282
+ self.post_init()
283
+
284
+ def get_input_embeddings(self):
285
+ return self.model.embed
286
+
287
+ def set_input_embeddings(self, new_emb):
288
+ self.model.embed = new_emb
289
+
290
+ def forward(
291
+ self,
292
+ input_ids: Optional[torch.LongTensor] = None,
293
+ attention_mask: Optional[torch.Tensor] = None,
294
+ labels: Optional[torch.LongTensor] = None,
295
+ **kwargs,
296
+ ):
297
+ if input_ids is None:
298
+ raise ValueError("input_ids is required")
299
+ logits = self.model(input_ids)
300
+ loss = None
301
+ if labels is not None:
302
+ # Shift for causal LM loss
303
+ shift_logits = logits[:, :-1, :].contiguous()
304
+ shift_labels = labels[:, 1:].contiguous()
305
+ loss_fct = torch.nn.CrossEntropyLoss()
306
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
307
+ return CausalLMOutput(loss=loss, logits=logits)
308
+
309
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
310
+ # No past_key_values cache: rerun full sequence. Works everywhere, slower at decode.
311
+ return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
nsa/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from __future__ import annotations
nsa/cache/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from __future__ import annotations
nsa/cache/kv_cache.py ADDED
@@ -0,0 +1,66 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+
6
+ from nsa.core.block_index import BlockMeta
7
+
8
+
9
+ @dataclass
10
+ class NSA_KV:
11
+ K_sel: torch.Tensor # [B,G,S,Dk]
12
+ V_sel: torch.Tensor # [B,G,S,Dv]
13
+ K_win: torch.Tensor # [B,G,S,Dk]
14
+ V_win: torch.Tensor # [B,G,S,Dv]
15
+ # raw token-level seq for compressed branch
16
+ K_cmp_raw_seq: torch.Tensor # [B,G,S,Dk]
17
+ V_cmp_raw_seq: torch.Tensor # [B,G,S,Dv]
18
+ K_cmp: torch.Tensor # [B,G,S_cmp,Dk]
19
+ V_cmp: torch.Tensor # [B,G,S_cmp,Dv]
20
+ win_ptr: torch.Tensor # [B,G]
21
+ cmp_emit_next: torch.Tensor # [B,G]
22
+ meta: BlockMeta
23
+ reads_pred: torch.Tensor # [T] per decode step predicted total reads
24
+ reads_act_total: torch.Tensor # [T]
25
+ reads_act_sel: torch.Tensor # [T]
26
+ reads_act_cmp: torch.Tensor # [T]
27
+ reads_act_win: torch.Tensor # [T]
28
+
29
+ def update_selection_raw(self, K: torch.Tensor, V: torch.Tensor) -> None:
30
+ self.K_sel = torch.cat([self.K_sel, K], dim=2)
31
+ self.V_sel = torch.cat([self.V_sel, V], dim=2)
32
+
33
+ def update_window(self, K: torch.Tensor, V: torch.Tensor, w: int) -> None:
34
+ self.K_win = torch.cat([self.K_win, K], dim=2)
35
+ self.V_win = torch.cat([self.V_win, V], dim=2)
36
+ # keep last w tokens
37
+ if self.K_win.shape[2] > w:
38
+ self.K_win = self.K_win[:, :, -w:, :]
39
+ self.V_win = self.V_win[:, :, -w:, :]
40
+
41
+ def update_compressed(
42
+ self, K_raw_cmp: torch.Tensor, V_raw_cmp: torch.Tensor, l: int, d: int
43
+ ) -> None:
44
+ # M0 prefill path: rebuild fully using avg-pool ϕ handled upstream
45
+ self.K_cmp = K_raw_cmp
46
+ self.V_cmp = V_raw_cmp
47
+
48
+ def append_cmp_raw(self, K_raw_tok: torch.Tensor, V_raw_tok: torch.Tensor) -> None:
49
+ self.K_cmp_raw_seq = torch.cat([self.K_cmp_raw_seq, K_raw_tok], dim=2)
50
+ self.V_cmp_raw_seq = torch.cat([self.V_cmp_raw_seq, V_raw_tok], dim=2)
51
+
52
+ def append_reads_pred(self, value: int) -> None:
53
+ v = torch.tensor([value], dtype=torch.int64, device=self.K_sel.device)
54
+ self.reads_pred = torch.cat([self.reads_pred, v], dim=0) if self.reads_pred.numel() else v
55
+
56
+ def append_reads_actual(self, total: int, sel: int, cmp: int, win: int) -> None:
57
+ dev = self.K_sel.device
58
+
59
+ def cat_or_set(t: torch.Tensor, val: int) -> torch.Tensor:
60
+ v = torch.tensor([val], dtype=torch.int64, device=dev)
61
+ return torch.cat([t, v], dim=0) if t.numel() else v
62
+
63
+ self.reads_act_total = cat_or_set(self.reads_act_total, total)
64
+ self.reads_act_sel = cat_or_set(self.reads_act_sel, sel)
65
+ self.reads_act_cmp = cat_or_set(self.reads_act_cmp, cmp)
66
+ self.reads_act_win = cat_or_set(self.reads_act_win, win)
nsa/core/README.md ADDED
@@ -0,0 +1,27 @@
1
+ # NSA Core Modules — Map and Responsibilities
2
+
3
+ Purpose
4
+ - Quick orientation for contributors. Links to architecture and tests mapping.
5
+
6
+ Modules
7
+ - `nsa_attention.py`: Top‑level attention module. Branch wiring (cmp/sel/win), gate MLP (τ=1.0, zero‑init last layer), strict masks, decode caches (`K_sel/V_sel`, `K_win/V_win`), counters.
8
+ - `selection_scorer.py`: Selection pipeline — compute p_cmp, map to p_slc (Eq.9 CSR), group reduce (Eq.10), deterministic top‑n, range construction (v2 vectorized), NVTX tags.
9
+ - `block_index.py`: CSR for cmp→sel fractional overlaps, conversions (CSR↔COO), helpers.
10
+ - `compress_pool.py`: Compressed branch pooling ϕ, emission schedule (warmup l, stride d), RoPE ordering.
11
+ - `attention_kernels.py`: SDPA variants — packed selection, masked SDPA, varlen helpers; FA‑2 wrappers (opt‑in) for cmp/win.
12
+ - `packing.py`: Range packing, index normalization, adjacency merge/de‑dup.
13
+ - `rope.py`: RoPE application for Q and per‑branch K before ϕ.
14
+ - `flags.py`: Environment flags and routing toggles.
15
+ - `debug.py`, `collate.py`: Debug helpers and varlen collate utilities.
16
+
17
+ Key Invariants (guarded by tests)
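+ A hedged sketch of the compressed-branch emission schedule noted for `compress_pool.py` above, matching the num_cmp rule used by the SDPA reference kernels (warmup l, stride d):
+
+ ```python
+ def num_cmp(t: int, l: int = 32, d: int = 16) -> int:
+     # Compressed tokens emitted after seeing t+1 raw tokens:
+     # none during the first l tokens, then one more every d tokens.
+     return 0 if (t + 1) < l else ((t + 1 - l) // d) + 1
+
+ print([num_cmp(t) for t in (0, 30, 31, 47, 63)])  # [0, 0, 1, 2, 3]
+ ```
+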
18
+ - Strict causality masks (see `nsa/tests/test_masks.py`).
19
+ - Group consistency (Eq.10) (see `nsa/tests/test_group_consistency*.py`).
20
+ - Selection rules (tie‑break, merge/de‑dup/clamp) (see `nsa/tests/test_selection_*`, `test_ranges_normalization.py`).
21
+ - Decode reads counters formula (see `nsa/tests/test_decode_counters.py`).
22
+
23
+ References
24
+ - Architecture Overview: Documentation/Architecture/Overview.md
25
+ - Selection Semantics: Documentation/Architecture/Selection-Semantics.md
26
+ - Tests Index: Documentation/Tests/Index.md
27
+
nsa/core/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from __future__ import annotations
nsa/core/attention_kernels.py ADDED
@@ -0,0 +1,1403 @@
1
+ from __future__ import annotations
2
+ import os
3
+ import time
4
+ from typing import Dict, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from nsa.core.debug import log
10
+ from nsa.core.packing import (
11
+ build_cu_seqlens_for_buckets,
12
+ build_length_buckets,
13
+ compute_compressed_lengths,
14
+ compute_sliding_lengths,
15
+ )
16
+ from nsa.kernels.flash_wrappers import (
17
+ attention_bgh,
18
+ attention_fa2_dense_batch,
19
+ attention_fa2_varlen,
20
+ fa2_supported,
21
+ fa2_supported_verbose,
22
+ is_flash_varlen_available,
23
+ )
24
+
25
+ # Simple grow-on-demand workspaces for varlen packing to avoid frequent allocations
26
+ _VARLEN_WS: Dict[Tuple, Dict[str, torch.Tensor]] = {}
27
+ _SEL_PACK_WS: Dict[Tuple, Dict[str, torch.Tensor]] = {}
28
+
29
+
30
+ def _env_int(name: str, default: int) -> int:
31
+ try:
32
+ v = int(os.getenv(name, str(default)))
33
+ return v
34
+ except Exception:
35
+ return default
36
+
37
+
38
+ def _env_int_bounded(name: str, default: int, min_val: int = 0, max_val: int = 10**8) -> int:
39
+ """Read integer from environment with bounds checking to prevent excessive memory allocation."""
40
+ try:
41
+ v = int(os.getenv(name, str(default)))
42
+ if v < min_val:
43
+ return min_val
44
+ if v > max_val:
45
+ # Log warning if value exceeds max
46
+ import warnings
47
+
48
+ warnings.warn(f"{name}={v} exceeds maximum {max_val}, clamping to {max_val}")
49
+ return max_val
50
+ return v
51
+ except Exception:
52
+ return default
53
+
54
+
55
+ def clear_varlen_workspaces() -> None:
56
+ """Optional memory cleanup: free varlen packing workspaces."""
57
+ _VARLEN_WS.clear()
58
+
59
+
60
+ def clear_selection_pack_workspaces() -> None:
61
+ """Optional memory cleanup: free selection pack workspaces."""
62
+ _SEL_PACK_WS.clear()
63
+
64
+
65
+ def _get_varlen_workspace(
66
+ device: torch.device,
67
+ dtype_q: torch.dtype,
68
+ dtype_k: torch.dtype,
69
+ dtype_v: torch.dtype,
70
+ h: int,
71
+ d_k: int,
72
+ d_v: int,
73
+ cap_N: int,
74
+ cap_total_k: int,
75
+ ) -> dict[str, torch.Tensor]:
76
+ key = (str(device), dtype_q, dtype_k, dtype_v, h, d_k, d_v)
77
+ ws = _VARLEN_WS.get(key)
78
+ need_new = ws is None
79
+ if not need_new:
80
+ q, k, v = ws["q"], ws["k"], ws["v"]
81
+ cuq, cuk = ws["cuq"], ws["cuk"]
82
+ need_new = (
83
+ q.shape[0] < cap_N
84
+ or k.shape[0] < cap_total_k
85
+ or v.shape[0] < cap_total_k
86
+ or cuq.numel() < (cap_N + 1)
87
+ or cuk.numel() < (cap_N + 1)
88
+ )
89
+ if need_new:
90
+ # Allow pre-sizing via env to avoid growth reallocations on long runs
91
+ # Bounded to prevent excessive memory allocation (max 1M rows, 100M total K/V)
92
+ reserve_N = _env_int_bounded("NSA_VARLEN_RESERVE_N", 0, 0, 10**6)
93
+ reserve_K = _env_int_bounded("NSA_VARLEN_RESERVE_K", 0, 0, 10**8)
94
+ new_N = max(cap_N, reserve_N, 1)
95
+ new_K = max(cap_total_k, reserve_K, 1)
96
+ ws = {
97
+ "q": torch.empty((new_N, h, d_k), dtype=dtype_q, device=device),
98
+ "k": torch.empty((new_K, h, d_k), dtype=dtype_k, device=device),
99
+ "v": torch.empty((new_K, h, d_v), dtype=dtype_v, device=device),
100
+ "cuq": torch.empty((new_N + 1,), dtype=torch.int32, device=device),
101
+ "cuk": torch.empty((new_N + 1,), dtype=torch.int32, device=device),
102
+ }
103
+ _VARLEN_WS[key] = ws
104
+ return ws
105
+
106
+
107
+ def batched_causal_attention_compressed(
108
+ Q: torch.Tensor, # [B,S,G,h,Dk]
109
+ K_cmp: torch.Tensor, # [B,G,S_cmp,Dk]
110
+ V_cmp: torch.Tensor, # [B,G,S_cmp,Dv]
111
+ l: int,
112
+ d: int,
113
+ ) -> torch.Tensor: # [B,S,G,h,Dv]
114
+ """
115
+ Compressed branch attention with per-row causal mask derived from emission schedule.
116
+ We cannot rely on is_causal=True due to S_q != S_kv and variable allowed lengths per t.
117
+ """
118
+ B, S, G, h, Dk = Q.shape
119
+ S_cmp = K_cmp.shape[2]
120
+ device = Q.device
121
+
122
+ # num_cmp(t) = 0 if t+1 < l else floor((t+1 - l) / d) + 1, clamped to S_cmp
123
+ tpos = torch.arange(S, device=device)
124
+ num_cmp = torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(max=S_cmp)
125
+ # A per-row disallowed mask would be (col >= num_cmp.view(S, 1)) with col = arange(S_cmp)
+ # (True means masked), but the exact per-t loop below enforces the same schedule directly.
128
+ # Enforce token-level causality as well: no compressed tokens emitted from future blocks beyond t
129
+ # When l=d=1, S_cmp == S and this reduces to standard causal
130
+
131
+ # Parity-first: exact per-t using attention_bgh
132
+ out = torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
133
+ log("cmp.begin", B=B, S=S, S_cmp=int(S_cmp), l=l, d=d)
134
+ for t in range(S):
135
+ L = int(num_cmp[t].item())
136
+ if L <= 0:
137
+ out[:, t] = 0.0
138
+ continue
139
+ q_t = Q[:, t]
140
+ k_t = K_cmp[:, :, :L, :]
141
+ v_t = V_cmp[:, :, :L, :]
142
+ out[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
143
+ log("cmp.step", t=int(t), L=L)
144
+ return out
145
+
146
+
147
+ def sliding_window_attention(
148
+ Q: torch.Tensor, # [B,S,G,h,Dk]
149
+ K: torch.Tensor, # [B,G,S,Dk]
150
+ V: torch.Tensor, # [B,G,S,Dv]
151
+ w: int,
152
+ ) -> torch.Tensor: # [B,S,G,h,Dv]
153
+ B, S, G, h, Dk = Q.shape
154
+ # Empty or zero window → zeros
155
+ if w <= 0 or K.shape[2] == 0 or S == 0:
156
+ return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device)
157
+ device = Q.device
158
+ # Build banded causal mask once: allowed keys per row t are [t-w+1 .. t]
159
+ row = torch.arange(S, device=device).view(S, 1)
160
+ col = torch.arange(S, device=device).view(1, S)
161
+ allowed = (col <= row) & (col >= (row - (w - 1))) # [S,S]
162
+ # Use additive float mask with -inf for disallowed positions to avoid NaNs
163
+ # across SDPA backends/dtypes. Shape: [S,S] then broadcast to [B,G*h,S,S].
164
+ Mf2d = torch.full((S, S), float("-inf"), dtype=Q.dtype, device=device)
165
+ Mf2d.masked_fill_(allowed, 0.0)
166
+ # Prepare SDPA tensors: [B, G*h, S, D*]
167
+ Qf = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous() # [B,G*h,S,Dk]
168
+ Kf = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B, G * h, S, Dk).contiguous()
169
+ Vf = (
170
+ V.unsqueeze(2)
171
+ .expand(B, G, h, S, V.shape[-1])
172
+ .reshape(B, G * h, S, V.shape[-1])
173
+ .contiguous()
174
+ )
175
+ # Broadcast additive mask to [B,G*h,S,S]
176
+ Mf = Mf2d.view(1, 1, S, S).expand(B, G * h, S, S)
177
+ Of = F.scaled_dot_product_attention(Qf, Kf, Vf, attn_mask=Mf) # [B,G*h,S,Dv]
178
+ Of = Of.transpose(1, 2).reshape(B, S, G, h, V.shape[-1])
179
+ return Of
180
+
181
+
182
+ def grouped_selection_attention(
183
+ Q: torch.Tensor, # [B,S,G,h,Dk]
184
+ K: torch.Tensor, # [B,G,S_kv,Dk]
185
+ V: torch.Tensor, # [B,G,S_kv,Dv]
186
+ ranges: torch.Tensor, # [B,S,G,n,2]
187
+ ) -> torch.Tensor: # [B,S,G,h,Dv]
188
+ B, S, G, h, Dk = Q.shape
189
+ # (The selectable KV length is K.shape[2]; not needed explicitly on this path.)
190
+
191
+ # Path 1: exact sequential-equivalence gather per (b,t,g)
192
+ out = torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device)
193
+ for b in range(B):
194
+ for t in range(S):
195
+ for g in range(G):
196
+ # build exact gather index list
197
+ idxs = []
198
+ for i in range(ranges.shape[3]):
199
+ s0 = int(ranges[b, t, g, i, 0].item())
200
+ e0 = int(ranges[b, t, g, i, 1].item())
201
+ if e0 > s0:
202
+ idxs.append(torch.arange(s0, e0, device=V.device))
203
+ if idxs:
204
+ idx = torch.cat(idxs)
205
+ k = K[b, g, idx] # [L,Dk]
206
+ v = V[b, g, idx] # [L,Dv]
207
+ q = Q[b, t, g] # [h,Dk]
208
+ # Expand per-head kv and add query-length dim for SDPA
209
+ q_btgh = q.unsqueeze(0).unsqueeze(2) # [1,h,1,Dk]
210
+ k_btgh = (
211
+ k.unsqueeze(0).unsqueeze(0).expand(1, q.shape[0], k.shape[0], k.shape[1])
212
+ ) # [1,h,L,Dk]
213
+ v_btgh = (
214
+ v.unsqueeze(0).unsqueeze(0).expand(1, q.shape[0], v.shape[0], v.shape[1])
215
+ ) # [1,h,L,Dv]
216
+ q_btgh = q_btgh.contiguous()
217
+ k_btgh = k_btgh.contiguous()
218
+ v_btgh = v_btgh.contiguous()
219
+ attn = F.scaled_dot_product_attention(
220
+ q_btgh, k_btgh, v_btgh, is_causal=True
221
+ ) # [1,h,1,Dv]
222
+ out[b, t, g] = attn.squeeze(0).squeeze(1) # [h,Dv]
223
+ log("sel.step", b=int(b), t=int(t), g=int(g), L=int(k.shape[0]))
224
+ else:
225
+ out[b, t, g] = 0.0
226
+ log("sel.step", b=int(b), t=int(t), g=int(g), L=0)
227
+ return out
228
+
229
+
230
+ def sliding_window_attention_masked(
231
+ Q: torch.Tensor, # [B,S,G,h,Dk]
232
+ K: torch.Tensor, # [B,G,S,Dk]
233
+ V: torch.Tensor, # [B,G,S,Dv]
234
+ w: int,
235
+ ) -> torch.Tensor: # [B,S,G,h,Dv]
236
+ # Memory-friendly masked semantics: only the first element in [start..t] is attended.
237
+ # With a single allowed key per row, SDPA reduces to returning that V directly.
238
+ B, S, G, h, Dk = Q.shape
239
+ if w <= 0 or K.shape[2] == 0:
240
+ return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device)
241
+ device = Q.device
242
+ tpos = torch.arange(S, device=device)
243
+ start = (tpos - (w - 1)).clamp_min(0) # [S]
244
+ # Build per-(B,G,S) gather indices and fetch V at start
245
+ idx = start.view(1, 1, S, 1).expand(B, G, S, 1) # [B,G,S,1]
246
+ v_sel = torch.gather(V, 2, idx.expand(B, G, S, V.shape[-1])) # [B,G,S,Dv]
247
+ # Expand across heads; result [B,S,G,h,Dv]
248
+ Of = v_sel.permute(0, 2, 1, 3).unsqueeze(3).expand(B, S, G, h, V.shape[-1])
249
+ return Of
250
+
251
+
252
+ def batched_causal_attention_compressed_masked(
253
+ Q: torch.Tensor, # [B,S,G,h,Dk]
254
+ K_cmp: torch.Tensor, # [B,G,S_cmp,Dk]
255
+ V_cmp: torch.Tensor, # [B,G,S_cmp,Dv]
256
+ l: int,
257
+ d: int,
258
+ ) -> torch.Tensor: # [B,S,G,h,Dv]
259
+ # Memory-friendly masked semantics: if num_cmp(t)>0, attend only to index 0 → return V[:, :, 0].
260
+ B, S, G, h, Dk = Q.shape
261
+ S_cmp = K_cmp.shape[2]
262
+ device = Q.device
263
+ if S_cmp == 0:
264
+ return torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
265
+ tpos = torch.arange(S, device=device)
266
+ num_cmp = torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(min=0, max=S_cmp) # [S]
267
+ have_any = (num_cmp > 0).view(1, S, 1, 1, 1).expand(B, S, G, h, 1)
268
+ v0 = V_cmp[:, :, 0, :] # [B,G,Dv]
269
+ v0f = v0.unsqueeze(1).unsqueeze(3).expand(B, S, G, h, V_cmp.shape[-1])
270
+ Of = torch.where(have_any, v0f, torch.zeros_like(v0f))
271
+ return Of
272
+
273
+
274
+ def grouped_selection_attention_packed(
275
+ Q: torch.Tensor, # [B,S,G,h,Dk]
276
+ K: torch.Tensor, # [B,G,S_kv,Dk]
277
+ V: torch.Tensor, # [B,G,S_kv,Dv]
278
+ ranges: torch.Tensor, # [B,S,G,n,2]
279
+ ) -> torch.Tensor: # [B,S,G,h,Dv]
280
+ """
281
+ Bucketed varlen packing by row length L with parity to gather path.
282
+ For each (b,t,g), build its flat index list from ranges, bucket rows
283
+ by identical L, and run one SDPA per bucket.
284
+ """
285
+ B, S, G, h, Dk = Q.shape
286
+ # (The selectable KV length is K.shape[2]; not needed explicitly on this path.)
287
+ device = Q.device
288
+ # Initialize output
289
+ out = torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=device)
290
+ # Flatten to row indices
291
+ rows = [] # list of (b,t,g, idx_tensor[L])
292
+ lengths = []
293
+ for b in range(B):
294
+ for t in range(S):
295
+ for g in range(G):
296
+ idxs = []
297
+ for i in range(ranges.shape[3]):
298
+ s0 = int(ranges[b, t, g, i, 0].item())
299
+ e0 = int(ranges[b, t, g, i, 1].item())
300
+ if e0 > s0:
301
+ idxs.append(torch.arange(s0, e0, device=device))
302
+ if idxs:
303
+ idx = torch.cat(idxs)
304
+ else:
305
+ idx = torch.empty((0,), dtype=torch.long, device=device)
306
+ rows.append((b, t, g, idx))
307
+ lengths.append(idx.numel())
308
+ if not rows:
309
+ return out
310
+ lengths_t = torch.tensor(lengths, device=device)
311
+ unique_L = torch.unique(lengths_t)
312
+ # Enable autograd-safe packing during training or when forced by env
313
+ use_safe_pack = (
314
+ torch.is_grad_enabled() and (Q.requires_grad or K.requires_grad or V.requires_grad)
315
+ ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)
316
+
317
+ for Lval in unique_L.tolist():
318
+ L = int(Lval)
319
+ # collect row indices for this bucket
320
+ bucket_idx = [i for i, Lx in enumerate(lengths) if Lx == L]
321
+ if L == 0 or len(bucket_idx) == 0:
322
+ # rows with L=0 remain zeros
323
+ continue
324
+ N = len(bucket_idx)
325
+ if use_safe_pack:
326
+ # Graph-friendly packing using stack to preserve autograd links
327
+ map_rows = []
328
+ Q_list = []
329
+ K_list = []
330
+ V_list = []
331
+ for ridx in bucket_idx:
332
+ b, t, g, idx = rows[ridx]
333
+ map_rows.append((b, t, g))
334
+ Q_list.append(Q[b, t, g]) # [h,Dk]
335
+ K_list.append(K[b, g, idx]) # [L,Dk]
336
+ V_list.append(V[b, g, idx]) # [L,Dv]
337
+ Qb = torch.stack(Q_list, dim=0) # [N,h,Dk]
338
+ Kb = torch.stack(K_list, dim=0) # [N,L,Dk]
339
+ Vb = torch.stack(V_list, dim=0) # [N,L,Dv]
340
+ q_btgh = Qb.unsqueeze(1).permute(0, 2, 1, 3) # [N,h,1,Dk]
341
+ k_btgh = Kb.unsqueeze(1).expand(N, h, L, Dk)
342
+ v_btgh = Vb.unsqueeze(1).expand(N, h, L, V.shape[-1])
343
+ attn = F.scaled_dot_product_attention(q_btgh, k_btgh, v_btgh, is_causal=True)
344
+ Ob = attn.squeeze(2) # [N,h,Dv]
345
+ for j, (b, t, g) in enumerate(map_rows):
346
+ out[b, t, g] = Ob[j]
347
+ else:
348
+ # Workspace-backed Q, K, V batches to reduce allocations
349
+ ws_key = (str(device), Q.dtype, K.dtype, V.dtype, h, Dk, V.shape[-1])
350
+ ws = _SEL_PACK_WS.get(ws_key)
351
+ need_new = (
352
+ ws is None or ws["Q"].shape[0] < N or ws["K"].shape[1] < L or ws["V"].shape[1] < L
353
+ )
354
+ if need_new:
355
+ # Allow pre-sizing via env to reduce reallocations
356
+ # Bounded to prevent excessive memory allocation (max 100K rows, 10K length)
357
+ reserve_N = _env_int_bounded("NSA_SEL_PACK_RESERVE_N", 0, 0, 10**5)
358
+ reserve_L = _env_int_bounded("NSA_SEL_PACK_RESERVE_L", 0, 0, 10**4)
359
+ new_N = max(N, reserve_N)
360
+ new_L = max(L, reserve_L)
361
+ Qb = torch.empty((new_N, h, Dk), dtype=Q.dtype, device=device)
362
+ Kb = torch.empty((new_N, new_L, Dk), dtype=K.dtype, device=device)
363
+ Vb = torch.empty((new_N, new_L, V.shape[-1]), dtype=V.dtype, device=device)
364
+ _SEL_PACK_WS[ws_key] = {"Q": Qb, "K": Kb, "V": Vb}
365
+ else:
366
+ Qb = _SEL_PACK_WS[ws_key]["Q"][:N]
367
+ Kb = _SEL_PACK_WS[ws_key]["K"][:N, :L]
368
+ Vb = _SEL_PACK_WS[ws_key]["V"][:N, :L]
369
+ # Populate workspace buffers and perform SDPA (execute for both new and reused workspaces)
370
+ map_rows = []
371
+ for j, ridx in enumerate(bucket_idx):
372
+ b, t, g, idx = rows[ridx]
373
+ Qb[j] = Q[b, t, g] # [h,Dk]
374
+ Kb[j] = K[b, g, idx] # [L,Dk]
375
+ Vb[j] = V[b, g, idx] # [L,Dv]
376
+ map_rows.append((b, t, g))
377
+ # SDPA per bucket: expand per-head
378
+ q_btgh = Qb.unsqueeze(1) # [N,1,h,Dk]
379
+ q_btgh = q_btgh.permute(0, 2, 1, 3) # [N,h,1,Dk]
380
+ k_btgh = Kb.unsqueeze(1).expand(N, h, L, Dk)
381
+ v_btgh = Vb.unsqueeze(1).expand(N, h, L, V.shape[-1])
382
+ attn = F.scaled_dot_product_attention(
383
+ q_btgh, k_btgh, v_btgh, is_causal=True
384
+ ) # [N,h,1,Dv]
385
+ Ob = attn.squeeze(2) # [N,h,Dv]
386
+ # Scatter back
387
+ for j, (b, t, g) in enumerate(map_rows):
388
+ out[b, t, g] = Ob[j]
389
+ return out
390
+
391
+
392
+ def selection_attention_varlen_all(
393
+ Q: torch.Tensor, # [B,S,G,h,Dk]
394
+ K: torch.Tensor, # [B,G,S_kv,Dk]
395
+ V: torch.Tensor, # [B,G,S_kv,Dv]
396
+ ranges: torch.Tensor, # [B,S,G,n,2]
397
+ ) -> torch.Tensor: # [B,S,G,h,Dv]
398
+ """
399
+ Fully batched selection attention using varlen packing across all (B,S,G) rows.
400
+
401
+ If NSA_SEL_VARLEN_V2 is enabled (default), dispatches to the vectorized v2
402
+ packer. Otherwise uses the legacy v1 path (minimal loops with workspace).
403
+ """
404
+ # Optional v2 vectorized packer
405
+ if os.getenv("NSA_SEL_VARLEN_V2", "1").lower() in ("1", "true", "yes", "on"):
406
+ return selection_attention_varlen_all_v2(Q, K, V, ranges)
407
+ B, S, G, h, Dk = Q.shape
408
+ # Parity override: when enabled, force causal=True to match packed reference
409
+ _parity = os.getenv("NSA_SEL_VARLEN_FORCE_PARITY", "0").lower() in ("1", "true", "yes", "on")
410
+ if _parity:
411
+ # Force exact parity by delegating to the packed reference
412
+ return grouped_selection_attention_packed(Q, K, V, ranges)
413
+ device = Q.device
414
+ Dv = V.shape[-1]
415
+ out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
416
+ # Build row list and lengths from ranges (sum of segment lengths)
417
+ rows: list[tuple[int, int, int]] = []
418
+ lens: list[int] = []
419
+ for b in range(B):
420
+ for t in range(S):
421
+ for g in range(G):
422
+ L = 0
423
+ for i in range(ranges.shape[3]):
424
+ s0 = int(ranges[b, t, g, i, 0].item())
425
+ e0 = int(ranges[b, t, g, i, 1].item())
426
+ if e0 > s0:
427
+ L += e0 - s0
428
+ if L > 0:
429
+ rows.append((b, t, g))
430
+ lens.append(L)
431
+ N = len(rows)
432
+ if N == 0:
433
+ return out
434
+
435
+ total_k = int(sum(lens))
436
+ # Workspace-backed packing
437
+ ws = _get_varlen_workspace(
438
+ device,
439
+ dtype_q=Q.dtype,
440
+ dtype_k=K.dtype,
441
+ dtype_v=V.dtype,
442
+ h=h,
443
+ d_k=Dk,
444
+ d_v=Dv,
445
+ cap_N=N,
446
+ cap_total_k=total_k,
447
+ )
448
+ q_pack = ws["q"][:N]
449
+ k_pack = ws["k"][:total_k]
450
+ v_pack = ws["v"][:total_k]
451
+ cuq = ws["cuq"][: N + 1]
452
+ cuk = ws["cuk"][: N + 1]
453
+ # Fill cu_seqlens
454
+ cuq.zero_()
455
+ cuk.zero_()
456
+ # Pack per row
457
+ write_pos = 0
458
+ for i, (b, t, g) in enumerate(rows):
459
+ # q for row
460
+ q_pack[i] = Q[b, t, g]
461
+ # iterate segments for this row
462
+ for j in range(ranges.shape[3]):
463
+ s0 = int(ranges[b, t, g, j, 0].item())
464
+ e0 = int(ranges[b, t, g, j, 1].item())
465
+ if e0 <= s0:
466
+ continue
467
+ seg_k = K[b, g, s0:e0] # [Lseg,Dk]
468
+ seg_v = V[b, g, s0:e0] # [Lseg,Dv]
469
+ Lseg = e0 - s0
470
+ # Assign using explicit expand_as to match target slice shape and avoid view pitfalls
471
+ _kslice = k_pack[write_pos : write_pos + Lseg]
472
+ _vslice = v_pack[write_pos : write_pos + Lseg]
473
+ _kslice.copy_(seg_k[:, None, :].expand_as(_kslice))
474
+ _vslice.copy_(seg_v[:, None, :].expand_as(_vslice))
475
+ write_pos += Lseg
476
+ cuq[i + 1] = cuq[i] + 1
477
+ cuk[i + 1] = cuk[i] + lens[i]
478
+ # Try FA‑2 varlen if available and supported. Default non-causal semantics;
479
+ # optionally force parity with packed path via NSA_SEL_VARLEN_FORCE_PARITY.
480
+ ok, _ = fa2_supported_verbose(device, Q.dtype, Dk)
481
+ if ok and is_flash_varlen_available():
482
+ try:
483
+ o_pack = attention_fa2_varlen(
484
+ q_pack,
485
+ k_pack,
486
+ v_pack,
487
+ cuq,
488
+ cuk,
489
+ max_seqlen_q=1,
490
+ max_seqlen_k=max(lens),
491
+ causal=_parity,
492
+ ) # [N,h,Dv]
493
+ # Scatter back
494
+ for i, (b, t, g) in enumerate(rows):
495
+ out[b, t, g] = o_pack[i]
496
+ return out
497
+ except Exception:
498
+ pass
499
+ # Dense batch per fixed L bucket as fallback
500
+ buckets: dict[int, list[int]] = {}
501
+ for i, L in enumerate(lens):
502
+ buckets.setdefault(L, []).append(i)
503
+ for L, idxs in buckets.items():
504
+ if L <= 0 or len(idxs) == 0:
505
+ continue
506
+ Nb = len(idxs)
507
+ Qb = torch.empty((Nb, h, Dk), dtype=Q.dtype, device=device)
508
+ Kb = torch.empty((Nb, L, Dk), dtype=K.dtype, device=device)
509
+ Vb = torch.empty((Nb, L, Dv), dtype=V.dtype, device=device)
510
+ tgt: list[tuple[int, int, int]] = []
511
+ for j, irow in enumerate(idxs):
512
+ b, t, g = rows[irow]
513
+ Qb[j] = Q[b, t, g]
514
+ # Rebuild fixed-length K/V for this row from ranges
515
+ write = 0
516
+ for rj in range(ranges.shape[3]):
517
+ s0 = int(ranges[b, t, g, rj, 0].item())
518
+ e0 = int(ranges[b, t, g, rj, 1].item())
519
+ if e0 <= s0:
520
+ continue
521
+ Lseg = e0 - s0
522
+ Kb[j, write : write + Lseg] = K[b, g, s0:e0]
523
+ Vb[j, write : write + Lseg] = V[b, g, s0:e0]
524
+ write += Lseg
525
+ tgt.append((b, t, g))
526
+ # Batched dense fallback for this bucket. Default non-causal; optionally force parity.
527
+ try:
528
+ q_rows = Qb.unsqueeze(1) # [Nb,1,h,Dk]
529
+ k_rows = Kb.unsqueeze(2).expand(Nb, L, h, Dk) # [Nb,L,h,Dk]
530
+ v_rows = Vb.unsqueeze(2).expand(Nb, L, h, Dv) # [Nb,L,h,Dv]
531
+ Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=_parity).squeeze(
532
+ 1
533
+ ) # [Nb,h,Dv]
534
+ for i, (b, t, g) in enumerate(tgt):
535
+ out[b, t, g] = Ob[i]
536
+ except Exception:
537
+ # Final fallback: per-row SDPA
538
+ for j, (b, t, g) in enumerate(tgt):
539
+ q_btgh = Qb[j].unsqueeze(0).unsqueeze(0) # [1,1,h,Dk]
540
+ k_btgh = Kb[j].unsqueeze(0).unsqueeze(0) # [1,1,L,Dk]
541
+ v_btgh = Vb[j].unsqueeze(0).unsqueeze(0) # [1,1,L,Dv]
542
+ out[b, t, g] = attention_bgh(q_btgh, k_btgh, v_btgh, causal=_parity)[0, 0]
543
+ return out
544
+
545
+
546
+ def selection_attention_varlen_all_v2(
547
+ Q: torch.Tensor,
548
+ K: torch.Tensor,
549
+ V: torch.Tensor,
550
+ ranges: torch.Tensor,
551
+ ) -> torch.Tensor:
552
+ """
553
+ Vectorized v2 varlen selection packer with FA‑2 varlen fast path and dense fallback.
554
+ - Eliminates Python loops for packing by using a difference-array mask to build per-row
555
+ allowed indices and flat-select K/V tokens.
556
+ - Uses causal=False for single‑query rows.
557
+ - Env: NSA_SEL_VARLEN_MIN_L to bypass on tiny rows (falls back to packed path).
558
+ """
559
+ B, S, G, h, Dk = Q.shape
560
+ # Parity override: when enabled, force causal=True to match packed reference
561
+ _parity = os.getenv("NSA_SEL_VARLEN_FORCE_PARITY", "0").lower() in ("1", "true", "yes", "on")
562
+ if _parity:
563
+ # Force exact parity by delegating to the packed reference
564
+ return grouped_selection_attention_packed(Q, K, V, ranges)
565
+ device = Q.device
566
+ Dv = V.shape[-1]
567
+ S_kv = K.shape[2]
568
+ out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
569
+ if S_kv == 0:
570
+ return out
571
+
572
+ # Build allowed mask [B,S,G,S_kv]
573
+ n = ranges.shape[3]
574
+ starts = ranges[..., 0].to(torch.int64).clamp_(0, S_kv)
575
+ ends = ranges[..., 1].to(torch.int64).clamp_(0, S_kv)
576
+ BSG = B * S * G
577
+ starts_f = starts.reshape(BSG, n)
578
+ ends_f = ends.reshape(BSG, n)
579
+ diff = torch.zeros((BSG, S_kv + 1), dtype=torch.int32, device=device)
580
+ one = torch.ones_like(starts_f, dtype=diff.dtype, device=device)
581
+ diff.scatter_add_(1, starts_f, one)
582
+ diff.scatter_add_(1, ends_f, -one)
583
+ allowed = diff[:, :-1].cumsum(dim=1).gt(0) # [BSG,S_kv]
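+ # Difference-array trick: scatter +1 at each range start and -1 at each range end
+ # (both clamped to S_kv), then a running cumsum over key positions counts how many
+ # [start, end) ranges cover each index; > 0 marks the key as selected. The extra
+ # trailing column absorbs ends clamped to S_kv and is dropped by diff[:, :-1].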
584
+
585
+ lens_flat = allowed.sum(dim=1, dtype=torch.int32) # [BSG]
586
+ row_mask = lens_flat.gt(0)
587
+ if not torch.any(row_mask):
588
+ return out
589
+ try:
590
+ min_L = int(os.getenv("NSA_SEL_VARLEN_MIN_L", "0"))
591
+ except Exception:
592
+ min_L = 0
593
+ if min_L > 0 and int(lens_flat.max().item()) < min_L:
594
+ return grouped_selection_attention_packed(Q, K, V, ranges)
595
+
596
+ idx_rows = torch.nonzero(row_mask, as_tuple=False).squeeze(1) # [N]
597
+ N = int(idx_rows.numel())
598
+ # (b,t,g) indices for scatter
599
+ b_idx = idx_rows // (S * G)
600
+ rem = idx_rows % (S * G)
601
+ t_idx = rem // G
602
+ g_idx = rem % G
603
+
604
+ # Pack Q rows
605
+ Q_rows = Q.reshape(B * S * G, h, Dk)[idx_rows]
606
+
607
+ # Map rows to b,g to select K/V
608
+ bg_map = (
609
+ torch.arange(B, device=device).view(B, 1, 1) * G
610
+ + torch.arange(G, device=device).view(1, 1, G)
611
+ ).expand(B, S, G)
612
+ bg_rows = bg_map.reshape(B * S * G)[idx_rows]
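+ # bg_map flattens (b, g) into the row index b*G + g used by the reshaped K/V below,
+ # so each surviving (b, t, g) query row gathers its group's full key/value stream.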
613
+ K_bg = K.reshape(B * G, S_kv, Dk)[bg_rows]
614
+ V_bg = V.reshape(B * G, S_kv, Dv)[bg_rows]
615
+ allowed_rows = allowed[idx_rows]
616
+
617
+ total_k = int(lens_flat[row_mask].sum().item())
618
+ sel_k = K_bg[allowed_rows] # [total_k, Dk]
619
+ sel_v = V_bg[allowed_rows] # [total_k, Dv]
620
+ lens_sel = lens_flat[row_mask] # [N]
621
+
622
+ # Workspace-backed packing
623
+ ws = _get_varlen_workspace(
624
+ device,
625
+ dtype_q=Q.dtype,
626
+ dtype_k=K.dtype,
627
+ dtype_v=V.dtype,
628
+ h=h,
629
+ d_k=Dk,
630
+ d_v=Dv,
631
+ cap_N=N,
632
+ cap_total_k=total_k,
633
+ )
634
+ q_pack = ws["q"][:N]
635
+ k_pack = ws["k"][:total_k]
636
+ v_pack = ws["v"][:total_k]
637
+ cuq = ws["cuq"][: N + 1]
638
+ cuk = ws["cuk"][: N + 1]
639
+
640
+ q_pack.copy_(Q_rows)
641
+ k_pack.copy_(sel_k.unsqueeze(1).expand(total_k, h, Dk))
642
+ v_pack.copy_(sel_v.unsqueeze(1).expand(total_k, h, Dv))
643
+ cuq.copy_(torch.arange(0, N + 1, device=device, dtype=torch.int32))
644
+ cuk[0] = 0
645
+ torch.cumsum(lens_sel.to(torch.int32), dim=0, out=cuk[1:])
646
+
647
+ # FA‑2 varlen (non-causal)
648
+ ok, _why = fa2_supported_verbose(device, Q.dtype, Dk)
649
+ max_len = int(lens_sel.max().item())
650
+ if ok and is_flash_varlen_available():
651
+ try:
652
+ o_pack = attention_fa2_varlen(
653
+ q_pack,
654
+ k_pack,
655
+ v_pack,
656
+ cuq,
657
+ cuk,
658
+ max_seqlen_q=1,
659
+ max_seqlen_k=max_len,
660
+ causal=_parity,
661
+ )
662
+ out[b_idx, t_idx, g_idx] = o_pack
663
+ return out
664
+ except Exception:
665
+ pass
666
+
667
+ # Correctness-first fallback: masked SDPA over an allowed key mask
668
+ # This path matches the non-causal packed reference exactly and avoids
669
+ # potential packing/indexing pitfalls in dense-bucket fallbacks.
670
+ try:
671
+ return grouped_selection_attention_masked(Q, K, V, ranges)
672
+ except Exception:
673
+ pass
674
+
675
+ # Legacy dense fallback by length buckets (kept as a final fallback)
676
+ starts = cuk[:-1].to(torch.int64)
677
+ ends = cuk[1:].to(torch.int64)
678
+ Ls = (ends - starts).to(torch.int64)
679
+ for L in torch.unique(Ls).tolist():
680
+ if L <= 0:
681
+ continue
682
+ sel = (Ls == L).nonzero(as_tuple=False).squeeze(1)
683
+ if sel.numel() == 0:
684
+ continue
685
+ Nb = int(sel.numel())
686
+ Qb = q_pack[sel]
687
+ k_rows = torch.empty((Nb, L, h, Dk), dtype=K.dtype, device=device)
688
+ v_rows = torch.empty((Nb, L, h, Dv), dtype=V.dtype, device=device)
689
+ for j in range(Nb):
690
+ s0 = int(starts[sel[j]].item())
691
+ e0 = int(ends[sel[j]].item())
692
+ k_rows[j] = k_pack[s0:e0]
693
+ v_rows[j] = v_pack[s0:e0]
694
+ try:
695
+ Ob = attention_fa2_dense_batch(Qb.unsqueeze(1), k_rows, v_rows, causal=_parity).squeeze(1)
696
+ except Exception:
697
+ Ob = torch.empty((Nb, h, Dv), dtype=V.dtype, device=device)
698
+ for j in range(Nb):
699
+ Ob[j] = attention_bgh(Qb[j].unsqueeze(0), k_rows[j].unsqueeze(0), v_rows[j].unsqueeze(0), causal=_parity)[
700
+ 0
701
+ ]
702
+ out[b_idx[sel], t_idx[sel], g_idx[sel]] = Ob
703
+ return out
704
+
705
+
706
+ def grouped_selection_attention_masked(
707
+ Q: torch.Tensor, # [B,S,G,h,Dk]
708
+ K: torch.Tensor, # [B,G,S_kv,Dk]
709
+ V: torch.Tensor, # [B,G,S_kv,Dv]
710
+ ranges: torch.Tensor, # [B,S,G,n,2]
711
+ ) -> torch.Tensor: # [B,S,G,h,Dv]
712
+ """
713
+ Fully batched selection attention using an additive -inf mask.
714
+ Vectorized ranges→mask construction via prefix-sum trick (no Python loops).
715
+ """
716
+ B, S, G, h, Dk = Q.shape
717
+ S_kv = K.shape[2]
718
+ device = Q.device
719
+ if S_kv == 0:
720
+ return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=device)
721
+
722
+ # Vectorized allowed mask [B,S,G,S_kv] from ranges using difference array
723
+ n = ranges.shape[3]
724
+ starts = ranges[..., 0].to(torch.int64).clamp_(0, S_kv) # [B,S,G,n]
725
+ ends = ranges[..., 1].to(torch.int64).clamp_(0, S_kv) # [B,S,G,n]
726
+ BSG = B * S * G
727
+ starts_f = starts.reshape(BSG, n)
728
+ ends_f = ends.reshape(BSG, n)
729
+ diff = torch.zeros((BSG, S_kv + 1), dtype=torch.int32, device=device)
730
+ one = torch.ones_like(starts_f, dtype=diff.dtype, device=device)
731
+ diff.scatter_add_(1, starts_f, one)
732
+ diff.scatter_add_(1, ends_f, -one)
733
+ allowed = diff[:, :-1].cumsum(dim=1).gt(0).reshape(B, S, G, S_kv)
734
+
735
+ # Detect rows with no allowed keys (all False along key dimension)
736
+ row_has_any = allowed.any(dim=-1) # [B,S,G]
737
+ row_empty = ~row_has_any
738
+
739
+ # Prevent SDPA from seeing an all-−inf row which can produce NaNs.
740
+ # For originally empty rows, force a single safe key (index 0) to True,
741
+ # run SDPA, then zero their outputs afterward to preserve semantics.
742
+ if row_empty.any():
743
+ allowed_safe = allowed.clone()
744
+ flat = allowed_safe.view(B * S * G, S_kv)
745
+ row_empty_flat = row_empty.reshape(B * S * G)
746
+ if S_kv > 0:
747
+ flat[row_empty_flat, 0] = True
748
+ allowed_safe = flat.view_as(allowed_safe)
749
+ else:
750
+ allowed_safe = allowed
751
+
752
+ # Prepare SDPA tensors: [B,G*h,S, D*] and mask [B,G*h,S,S_kv]
753
+ Qf = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous() # [B,G*h,S,Dk]
754
+ Kf = K.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_kv, Dk).contiguous()
755
+ Vf = V.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_kv, V.shape[-1]).contiguous()
756
+ # Build additive mask in float32 for numerical stability with -inf
757
+ zeros = torch.zeros((B, G * h, S, S_kv), dtype=torch.float32, device=device)
758
+ neg_inf = torch.full((B, G * h, S, S_kv), float("-inf"), dtype=torch.float32, device=device)
759
+ Mf = torch.where(
760
+ allowed_safe.transpose(1, 2) # [B,G,S,S_kv]
761
+ .unsqueeze(2)
762
+ .expand(-1, -1, h, -1, -1)
763
+ .reshape(B, G * h, S, S_kv),
764
+ zeros,
765
+ neg_inf,
766
+ ).contiguous()
767
+
768
+ Of = F.scaled_dot_product_attention(Qf, Kf, Vf, attn_mask=Mf) # [B,G*h,S,Dv]
769
+ Of = Of.transpose(1, 2).reshape(B, S, G, h, V.shape[-1])
770
+ # Zero outputs for originally empty rows to preserve semantics
771
+ if row_empty.any():
772
+ Of = torch.where(row_has_any.unsqueeze(-1).unsqueeze(-1), Of, torch.zeros_like(Of))
773
+ return Of
774
+
775
+
776
+ # ===== FA-2 integration scaffolding (M1) =====
777
+
778
+
779
+ def _env_bool(name: str, default: bool = False) -> bool:
780
+ v = os.getenv(name, "1" if default else "0").lower()
781
+ return v in ("1", "true", "yes", "on")
782
+
783
+
784
+ def _is_sm89(device: torch.device) -> bool:
785
+ """Return True if running on CUDA device with SM 8.9 (Ada/RTX 4090)."""
786
+ if device.type != "cuda":
787
+ return False
788
+ try:
789
+ cap = torch.cuda.get_device_capability(device)
790
+ return cap == (8, 9)
791
+ except Exception:
792
+ return False
793
+
794
+
795
+ def _fa2_forced() -> bool:
796
+ """Return True if FA-2 usage is explicitly forced via env."""
797
+ return _env_bool("NSA_FA2_FORCE", False)
798
+
799
+
800
+ def sliding_window_attention_fa2(
801
+ Q: torch.Tensor, # [B,S,G,h,Dk]
802
+ K: torch.Tensor, # [B,G,S,Dk]
803
+ V: torch.Tensor, # [B,G,S,Dv]
804
+ w: int,
805
+ min_len_for_fa2: int = 16,
806
+ ) -> torch.Tensor:
807
+ """
808
+ FA-2 path for the sliding-window branch with safe fallbacks.
809
+ Attempts FA-2 (varlen, then dense per length bucket) when supported; otherwise falls back to masked SDPA to preserve numerics.
810
+ """
811
+ B, S, G, h, Dk = Q.shape
812
+ device = Q.device
813
+ # Policy: sliding-window FA-2 is disabled by default because of an API semantics
814
+ # limitation (the causal mask assumes keys start at position 0, which does not hold
815
+ # for shifted windows). Allow it only when explicitly enabled via NSA_ALLOW_SLIDING_FA2 or the force flags.
816
+ allow_sliding_fa2 = _env_bool("NSA_ALLOW_SLIDING_FA2", False)
817
+ # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
818
+ if _is_sm89(device) and not _fa2_forced():
819
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
820
+ log("fa2.gate_skip", branch="win", reason="sm89_guard", forced=bool(_fa2_forced()))
821
+ return sliding_window_attention(Q, K, V, w)
822
+ # Policy guard
823
+ if not allow_sliding_fa2 and not (
824
+ _env_bool("NSA_FA2_FORCE_VARLEN", False) or _env_bool("NSA_FA2_FORCE_DENSE", False)
825
+ ):
826
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
827
+ log("fa2.gate_skip", branch="win", reason="unsupported_sliding_semantics", forced=False)
828
+ return sliding_window_attention(Q, K, V, w)
829
+ # Compute effective per-row window lengths and buckets
830
+ lengths = compute_sliding_lengths(S, w, device)
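+ # Per-row effective window: the packing below slices K[b, g, max(0, (t+1)-w) : t+1]
+ # for each query position t, so the per-row key length is min(t+1, w); `lengths`
+ # holds these values and drives the min-length gate and length bucketing.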
831
+ max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0
832
+ # Allow override via env
833
+ try:
834
+ min_len_for_fa2 = int(os.getenv("NSA_FA2_MIN_LEN_WIN", str(min_len_for_fa2)))
835
+ except Exception:
836
+ pass
837
+ # Disable sentinel: non-positive threshold disables FA‑2 entirely for this branch
838
+ if min_len_for_fa2 <= 0:
839
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
840
+ log("fa2.gate_skip", branch="win", reason="disabled_threshold")
841
+ return sliding_window_attention(Q, K, V, w)
842
+ buckets = build_length_buckets(lengths)
843
+ if buckets:
844
+ log("fa2.win.buckets", n=len(buckets), max_len=max_len)
845
+ # Build cu_seqlens per bucket (for future FA-2 varlen call)
846
+ for idx in buckets:
847
+ blens = lengths[idx]
848
+ _ = build_cu_seqlens_for_buckets(blens)
849
+ # Small-length auto-switch to masked SDPA
850
+ if max_len < min_len_for_fa2:
851
+ if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
852
+ log(
853
+ "fa2.gate_skip",
854
+ branch="win",
855
+ reason="below_min_len",
856
+ max_len=int(max_len),
857
+ min_len=int(min_len_for_fa2),
858
+ )
859
+ return sliding_window_attention(Q, K, V, w)
860
+ # Capability check
861
+ ok, why = fa2_supported_verbose(device, Q.dtype, Dk)
862
+ if not ok or not is_flash_varlen_available():
863
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
864
+ log("fa2.gate_skip", branch="win", reason=why, has_varlen=is_flash_varlen_available())
865
+ return sliding_window_attention(Q, K, V, w)
866
+ # Attempt FA-2 across all rows using varlen first, then dense per-bucket. Fallback to masked SDPA on error.
867
+ try:
868
+ B, S, G, h, Dk = Q.shape
869
+ Dv = V.shape[-1]
870
+ use_timing = os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes")
871
+ force_varlen = _env_bool("NSA_FA2_FORCE_VARLEN", False)
872
+ force_dense = _env_bool("NSA_FA2_FORCE_DENSE", False)
873
+ force_win_dense = _env_bool("NSA_WIN_FORCE_DENSE", False)
874
+ # Log histogram of lengths
875
+ if buckets:
876
+ uniq, counts = torch.unique(lengths, return_counts=True)
877
+ log("fa2.win.hist", uniq=uniq.tolist(), counts=counts.tolist())
878
+ # Try a single varlen call across all rows
879
+ if (is_flash_varlen_available() and not (force_dense or force_win_dense)) or force_varlen:
880
+ rows = []
881
+ len_rows = []
882
+ for t in range(S):
883
+ L = int(lengths[t].item())
884
+ for b in range(B):
885
+ for g in range(G):
886
+ rows.append((b, t, g))
887
+ len_rows.append(L)
888
+ N = len(rows)
889
+ if N > 0 and max_len >= 1:
890
+ use_safe_pack = (
891
+ torch.is_grad_enabled()
892
+ and (Q.requires_grad or K.requires_grad or V.requires_grad)
893
+ ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)
894
+ if use_safe_pack:
895
+ # Autograd-safe packing via stack/cat to preserve graph links
896
+ q_pack = torch.stack([Q[b, t, g] for (b, t, g) in rows], dim=0) # [N,h,Dk]
897
+ k_rows = []
898
+ v_rows = []
899
+ for i, (b, t, g) in enumerate(rows):
900
+ L = len_rows[i]
901
+ if L > 0:
902
+ start = max(0, (t + 1) - w)
903
+ end = t + 1
904
+ seg_k = K[b, g, start:end].unsqueeze(1).expand(-1, h, -1) # [L,h,Dk]
905
+ seg_v = V[b, g, start:end].unsqueeze(1).expand(-1, h, -1) # [L,h,Dv]
906
+ k_rows.append(seg_k)
907
+ v_rows.append(seg_v)
908
+ total_k = int(sum(len_rows))
909
+ if total_k > 0:
910
+ k_pack = torch.cat(k_rows, dim=0)
911
+ v_pack = torch.cat(v_rows, dim=0)
912
+ else:
913
+ k_pack = torch.zeros((0, h, Dk), dtype=K.dtype, device=K.device)
914
+ v_pack = torch.zeros((0, h, Dv), dtype=V.dtype, device=V.device)
915
+ cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
916
+ lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
917
+ cuk = torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0)
918
+ else:
919
+ total_k = int(sum(len_rows))
920
+ ws = _get_varlen_workspace(
921
+ Q.device, Q.dtype, K.dtype, V.dtype, h, Dk, Dv, N, total_k
922
+ )
923
+ q_pack = ws["q"][:N]
924
+ k_pack = ws["k"][:total_k]
925
+ v_pack = ws["v"][:total_k]
926
+ # Build cumulative sequence lengths for Q and K
927
+ cuq = ws["cuq"][: N + 1]
928
+ cuq.copy_(torch.arange(0, N + 1, device=Q.device, dtype=torch.int32))
929
+ lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
930
+ cuk = ws["cuk"][: N + 1]
931
+ torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0, out=cuk)
932
+ # Fill packs
933
+ write_pos = 0
934
+ for i, (b, t, g) in enumerate(rows):
935
+ L = len_rows[i]
936
+ q_pack[i] = Q[b, t, g]
937
+ if L > 0:
938
+ start = max(0, (t + 1) - w)
939
+ end = t + 1
940
+ seg_k = K[b, g, start:end] # [L,Dk]
941
+ seg_v = V[b, g, start:end] # [L,Dv]
942
+ assert (write_pos + L) <= total_k, "varlen K/V pack overflow"
943
+ k_pack[write_pos : write_pos + L] = seg_k.unsqueeze(1).expand(L, h, Dk)
944
+ v_pack[write_pos : write_pos + L] = seg_v.unsqueeze(1).expand(L, h, Dv)
945
+ write_pos += L
946
+ # Optional integrity checks (debug only)
947
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
948
+ try:
949
+ assert cuq.numel() == (N + 1), "cuq length mismatch"
950
+ assert cuk.numel() == (N + 1), "cuk length mismatch"
951
+ assert int(cuk[-1].item()) == int(total_k), "cuk total_k mismatch"
952
+ if total_k > 0 and N > 0:
953
+ probe = [0, N // 2, N - 1] if N >= 3 else [0]
954
+ for i in probe:
955
+ L_i = int(len_rows[i])
956
+ b_i, t_i, g_i = rows[i]
957
+ s_i = int(max(0, (t_i + 1) - w))
958
+ e_i = int(t_i + 1)
959
+ if L_i > 0:
960
+ ks = k_pack[cuk[i] : cuk[i + 1]] # [L,h,Dk]
961
+ kv = K[b_i, g_i, s_i:e_i].unsqueeze(1).expand(-1, h, -1)
962
+ if ks.shape != kv.shape:
963
+ log(
964
+ "warn.fa2_win_pack_shape",
965
+ row=i,
966
+ ks=ks.shape,
967
+ kv=kv.shape,
968
+ )
969
+ else:
970
+ md = float((ks - kv).abs().max().item())
971
+ if md > 1e-3:
972
+ log(
973
+ "warn.fa2_win_pack_mismatch",
974
+ row=i,
975
+ L=L_i,
976
+ max_diff=md,
977
+ )
978
+ except Exception:
979
+ pass
980
+
981
+ if use_timing:
982
+ t0 = time.perf_counter()
983
+ o_pack = attention_fa2_varlen(
984
+ q_pack,
985
+ k_pack,
986
+ v_pack,
987
+ cuq,
988
+ cuk,
989
+ max_seqlen_q=1,
990
+ max_seqlen_k=max_len,
991
+ causal=False,
992
+ ) # [N,h,Dv]
993
+ if not torch.isfinite(o_pack).all():
994
+ log("warn.fa2_win_varlen_nonfinite")
995
+ return sliding_window_attention(Q, K, V, w)
996
+ if use_timing:
997
+ dt = (time.perf_counter() - t0) * 1e3
998
+ log("fa2.win.varlen_all", N=int(N), total_k=int(total_k), ms=dt)
999
+ # Scatter back
1000
+ out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
1001
+ for i, (b, t, g) in enumerate(rows):
1002
+ out[b, t, g] = o_pack[i]
1003
+ return out
1004
+ out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
1005
+ for idx in buckets:
1006
+ if idx.numel() == 0:
1007
+ continue
1008
+ L = int(lengths[idx[0]].item())
1009
+ # Collect rows for this bucket
1010
+ rows_q = [] # [N,h,Dk]
1011
+ rows_k = [] # [N,L,Dk]
1012
+ rows_v = [] # [N,L,Dv]
1013
+ tgt = []
1014
+ for t in idx.tolist():
1015
+ start = max(0, (t + 1) - w)
1016
+ end = t + 1
1017
+ for b in range(B):
1018
+ for g in range(G):
1019
+ rows_q.append(Q[b, t, g])
1020
+ rows_k.append(K[b, g, start:end])
1021
+ rows_v.append(V[b, g, start:end])
1022
+ tgt.append((b, t, g))
1023
+ if not rows_q:
1024
+ continue
1025
+ N = len(rows_q)
1026
+ Qb = torch.stack(rows_q, dim=0) # [N,h,Dk]
1027
+ Kb = torch.stack(rows_k, dim=0) # [N,L,Dk]
1028
+ Vb = torch.stack(rows_v, dim=0) # [N,L,Dv]
1029
+ if is_flash_varlen_available() and not (force_dense or force_win_dense):
1030
+ # Pack varlen (constant L here, but use API for generality)
1031
+ q_pack = Qb # [N,h,Dk]
1032
+ k_pack = Kb.reshape(N * L, Dk).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dk)
1033
+ v_pack = Vb.reshape(N * L, Dv).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dv)
1034
+ cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
1035
+ cuk = torch.arange(0, (N + 1) * L, step=L, device=Q.device, dtype=torch.int32)
1036
+ if use_timing:
1037
+ t0 = time.perf_counter()
1038
+ o_pack = attention_fa2_varlen(
1039
+ q_pack,
1040
+ k_pack,
1041
+ v_pack,
1042
+ cuq,
1043
+ cuk,
1044
+ max_seqlen_q=1,
1045
+ max_seqlen_k=L,
1046
+ causal=False,
1047
+ ) # [N,h,Dv]
1048
+ if not torch.isfinite(o_pack).all():
1049
+ log("warn.fa2_win_bucket_nonfinite")
1050
+ return sliding_window_attention(Q, K, V, w)
1051
+ if use_timing:
1052
+ dt = (time.perf_counter() - t0) * 1e3
1053
+ log("fa2.win.bucket", path="varlen", L=L, N=int(N), ms=dt)
1054
+ Ob = o_pack # [N,h,Dv]
1055
+ else:
1056
+ q_rows = Qb.unsqueeze(1) # [N,1,h,Dk]
1057
+ k_rows = Kb.unsqueeze(2).expand(N, L, h, Dk)
1058
+ v_rows = Vb.unsqueeze(2).expand(N, L, h, Dv)
1059
+ if use_timing:
1060
+ t0 = time.perf_counter()
1061
+ Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False).squeeze(
1062
+ 1
1063
+ ) # [N,h,Dv]
1064
+ if use_timing:
1065
+ dt = (time.perf_counter() - t0) * 1e3
1066
+ log("fa2.win.bucket", path="dense", L=L, N=int(N), ms=dt)
1067
+ for i, (b, t, g) in enumerate(tgt):
1068
+ out[b, t, g] = Ob[i]
1069
+ return out
1070
+ except Exception as e:
1071
+ log("warn.fa2_unexpected_fallback", branch="win", error=str(e)[:100])
1072
+ return sliding_window_attention_masked(Q, K, V, w)
1073
+
1074
+
1075
+ def compressed_attention_fa2(
1076
+ Q: torch.Tensor, # [B,S,G,h,Dk]
1077
+ K_cmp: torch.Tensor, # [B,G,S_cmp,Dk]
1078
+ V_cmp: torch.Tensor, # [B,G,S_cmp,Dv]
1079
+ l: int,
1080
+ d: int,
1081
+ min_len_for_fa2: int = 16,
1082
+ ) -> torch.Tensor:
1083
+ """
1084
+ FA-2 path for the compressed branch with safe fallbacks.
1085
+ Attempts FA-2 (varlen, then dense per length bucket) when supported; otherwise falls back to masked SDPA to preserve numerics.
1086
+ """
1087
+ B, S, G, h, Dk = Q.shape
1088
+ device = Q.device
1089
+ # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
1090
+ if _is_sm89(device) and not _fa2_forced():
1091
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
1092
+ log("fa2.gate_skip", branch="cmp", reason="sm89_guard", forced=bool(_fa2_forced()))
1093
+ return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
1094
+ S_cmp = K_cmp.shape[2]
1095
+ if S_cmp == 0:
1096
+ return torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
1097
+ num_cmp = compute_compressed_lengths(S, l, d, S_cmp, device)
1098
+ max_len = int(num_cmp.max().item()) if num_cmp.numel() > 0 else 0
1099
+ try:
1100
+ min_len_for_fa2 = int(os.getenv("NSA_FA2_MIN_LEN_CMP", str(min_len_for_fa2)))
1101
+ except Exception:
1102
+ pass
1103
+ # Disable sentinel: non-positive threshold disables FA‑2 entirely for this branch
1104
+ if min_len_for_fa2 <= 0:
1105
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
1106
+ log("fa2.gate_skip", branch="cmp", reason="disabled_threshold")
1107
+ return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
1108
+ buckets = build_length_buckets(num_cmp)
1109
+ if buckets:
1110
+ log("fa2.cmp.buckets", n=len(buckets), max_len=max_len)
1111
+ for idx in buckets:
1112
+ blens = num_cmp[idx]
1113
+ _ = build_cu_seqlens_for_buckets(blens)
1114
+ if max_len < min_len_for_fa2:
1115
+ if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
1116
+ log(
1117
+ "fa2.gate_skip",
1118
+ branch="cmp",
1119
+ reason="below_min_len",
1120
+ max_len=int(max_len),
1121
+ min_len=int(min_len_for_fa2),
1122
+ )
1123
+ return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
1124
+ ok, why = fa2_supported_verbose(device, Q.dtype, Dk)
1125
+ if not ok or not is_flash_varlen_available():
1126
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
1127
+ log("fa2.gate_skip", branch="cmp", reason=why, has_varlen=is_flash_varlen_available())
1128
+ return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
1129
+ try:
1130
+ Dv = V_cmp.shape[-1]
1131
+ use_timing = os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes")
1132
+ # Log histogram of lengths
1133
+ if buckets:
1134
+ uniq, counts = torch.unique(num_cmp, return_counts=True)
1135
+ log("fa2.cmp.hist", uniq=uniq.tolist(), counts=counts.tolist())
1136
+ # Try single varlen across all rows with L>0
1137
+ force_varlen = _env_bool("NSA_FA2_FORCE_VARLEN", False)
1138
+ force_dense = _env_bool("NSA_FA2_FORCE_DENSE", False)
1139
+ if ((is_flash_varlen_available() and not force_dense) or force_varlen) and max_len >= 1:
1140
+ rows = []
1141
+ len_rows = []
1142
+ for t in range(S):
1143
+ L = int(num_cmp[t].item())
1144
+ for b in range(B):
1145
+ for g in range(G):
1146
+ if L > 0:
1147
+ rows.append((b, t, g))
1148
+ len_rows.append(L)
1149
+ N = len(rows)
1150
+ if N > 0:
1151
+ total_k = int(sum(len_rows))
1152
+ use_safe_pack = (
1153
+ torch.is_grad_enabled()
1154
+ and (Q.requires_grad or K_cmp.requires_grad or V_cmp.requires_grad)
1155
+ ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)
1156
+ if use_safe_pack:
1157
+ q_pack = torch.stack([Q[b, t, g] for (b, t, g) in rows], dim=0)
1158
+ k_rows = []
1159
+ v_rows = []
1160
+ for (b, t, g), L in zip(rows, len_rows):
1161
+ if L > 0:
1162
+ seg_k = K_cmp[b, g, :L]
1163
+ seg_v = V_cmp[b, g, :L]
1164
+ k_rows.append(seg_k.unsqueeze(1).expand(-1, h, -1)) # [L,h,Dk]
1165
+ v_rows.append(seg_v.unsqueeze(1).expand(-1, h, -1)) # [L,h,Dv]
1166
+ if total_k > 0:
1167
+ k_pack = torch.cat(k_rows, dim=0)
1168
+ v_pack = torch.cat(v_rows, dim=0)
1169
+ else:
1170
+ k_pack = torch.zeros((0, h, Dk), dtype=K_cmp.dtype, device=K_cmp.device)
1171
+ v_pack = torch.zeros((0, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
1172
+ cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
1173
+ lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
1174
+ cuk = torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0)
1175
+ else:
1176
+ ws = _get_varlen_workspace(
1177
+ Q.device, Q.dtype, K_cmp.dtype, V_cmp.dtype, h, Dk, Dv, N, total_k
1178
+ )
1179
+ q_pack = ws["q"][:N]
1180
+ k_pack = ws["k"][:total_k]
1181
+ v_pack = ws["v"][:total_k]
1182
+ cuq = ws["cuq"][: N + 1]
1183
+ cuq.copy_(torch.arange(0, N + 1, device=Q.device, dtype=torch.int32))
1184
+ lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
1185
+ cuk = ws["cuk"][: N + 1]
1186
+ torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0, out=cuk)
1187
+ write_pos = 0
1188
+ for i, (b, t, g) in enumerate(rows):
1189
+ L = len_rows[i]
1190
+ q_pack[i] = Q[b, t, g]
1191
+ if L > 0:
1192
+ seg_k = K_cmp[b, g, :L]
1193
+ seg_v = V_cmp[b, g, :L]
1194
+ assert (write_pos + L) <= total_k, "varlen cmp K/V pack overflow"
1195
+ k_pack[write_pos : write_pos + L] = seg_k.unsqueeze(1).expand(L, h, Dk)
1196
+ v_pack[write_pos : write_pos + L] = seg_v.unsqueeze(1).expand(L, h, Dv)
1197
+ write_pos += L
1198
+ if use_timing:
1199
+ t0 = time.perf_counter()
1200
+ o_pack = attention_fa2_varlen(
1201
+ q_pack,
1202
+ k_pack,
1203
+ v_pack,
1204
+ cuq,
1205
+ cuk,
1206
+ max_seqlen_q=1,
1207
+ max_seqlen_k=max_len,
1208
+ causal=False,
1209
+ ) # [N,h,Dv]
1210
+ if not torch.isfinite(o_pack).all():
1211
+ log("warn.fa2_cmp_varlen_nonfinite")
1212
+ return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
1213
+ if use_timing:
1214
+ dt = (time.perf_counter() - t0) * 1e3
1215
+ log("fa2.cmp.varlen_all", N=int(N), total_k=int(total_k), ms=dt)
1216
+ out = torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
1217
+ for i, (b, t, g) in enumerate(rows):
1218
+ out[b, t, g] = o_pack[i]
1219
+ return out
1220
+ out = torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
1221
+ for idx in buckets:
1222
+ if idx.numel() == 0:
1223
+ continue
1224
+ L = int(num_cmp[idx[0]].item())
1225
+ rows_q = [] # [N,h,Dk]
1226
+ rows_k = [] # [N,L,Dk]
1227
+ rows_v = [] # [N,L, Dv]
1228
+ tgt = []
1229
+ for t in idx.tolist():
1230
+ if L <= 0:
1231
+ continue
1232
+ for b in range(B):
1233
+ for g in range(G):
1234
+ rows_q.append(Q[b, t, g])
1235
+ rows_k.append(K_cmp[b, g, :L])
1236
+ rows_v.append(V_cmp[b, g, :L])
1237
+ tgt.append((b, t, g))
1238
+ if not rows_q:
1239
+ continue
1240
+ N = len(rows_q)
1241
+ Qb = torch.stack(rows_q, dim=0) # [N,h,Dk]
1242
+ Kb = torch.stack(rows_k, dim=0) # [N,L,Dk]
1243
+ Vb = torch.stack(rows_v, dim=0) # [N,L,Dv]
1244
+ if is_flash_varlen_available() and not force_dense:
1245
+ q_pack = Qb
1246
+ k_pack = Kb.reshape(N * L, Dk).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dk)
1247
+ v_pack = Vb.reshape(N * L, Dv).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dv)
1248
+ cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
1249
+ cuk = torch.arange(0, (N + 1) * L, step=L, device=Q.device, dtype=torch.int32)
1250
+ if use_timing:
1251
+ t0 = time.perf_counter()
1252
+ o_pack = attention_fa2_varlen(
1253
+ q_pack,
1254
+ k_pack,
1255
+ v_pack,
1256
+ cuq,
1257
+ cuk,
1258
+ max_seqlen_q=1,
1259
+ max_seqlen_k=L,
1260
+ causal=False,
1261
+ ) # [N,h,Dv]
1262
+ if use_timing:
1263
+ dt = (time.perf_counter() - t0) * 1e3
1264
+ log("fa2.cmp.bucket", path="varlen", L=L, N=int(N), ms=dt)
1265
+ Ob = o_pack
1266
+ else:
1267
+ q_rows = Qb.unsqueeze(1)
1268
+ k_rows = Kb.unsqueeze(2).expand(N, L, h, Dk)
1269
+ v_rows = Vb.unsqueeze(2).expand(N, L, h, Dv)
1270
+ if use_timing:
1271
+ t0 = time.perf_counter()
1272
+ Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=True).squeeze(1)
1273
+ if not torch.isfinite(Ob).all():
1274
+ return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
1275
+ if use_timing:
1276
+ dt = (time.perf_counter() - t0) * 1e3
1277
+ log("fa2.cmp.bucket", path="dense", L=L, N=int(N), ms=dt)
1278
+ for i, (b, t, g) in enumerate(tgt):
1279
+ out[b, t, g] = Ob[i]
1280
+ return out
1281
+ except Exception as e:
1282
+ log("warn.fa2_unexpected_fallback", branch="cmp", error=str(e)[:100])
1283
+ return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
1284
+
1285
+
1286
+ def sliding_window_attention_fa2_decode(
1287
+ q_t: torch.Tensor, K_win: torch.Tensor, V_win: torch.Tensor, w: int
1288
+ ) -> torch.Tensor:
1289
+ B, G, h, Dk = q_t.shape
1290
+ # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
1291
+ if _is_sm89(q_t.device) and not _fa2_forced():
1292
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
1293
+ log(
1294
+ "fa2.gate_skip",
1295
+ branch="win.decode",
1296
+ reason="sm89_guard",
1297
+ forced=bool(_fa2_forced()),
1298
+ )
1299
+ end = K_win.shape[2]
1300
+ win_len = min(w, end)
1301
+ if win_len == 0:
1302
+ return torch.zeros((B, G, h, V_win.shape[-1]), dtype=V_win.dtype, device=V_win.device)
1303
+ start = end - win_len
1304
+ return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True)
1305
+ end = K_win.shape[2]
1306
+ win_len = min(w, end)
1307
+ if win_len == 0:
1308
+ return torch.zeros((B, G, h, V_win.shape[-1]), dtype=V_win.dtype, device=V_win.device)
1309
+ # CPU or unsupported: direct SDPA for parity
1310
+ ok, why = fa2_supported_verbose(q_t.device, q_t.dtype, Dk)
1311
+ if not ok:
1312
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
1313
+ log("fa2.gate_skip", branch="win.decode", reason=why)
1314
+ start = end - win_len
1315
+ return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True)
1316
+ # Small-length auto-switch for decode
1317
+ try:
1318
+ min_len = int(os.getenv("NSA_FA2_MIN_LEN_WIN", "16"))
1319
+ except Exception:
1320
+ min_len = 16
1321
+ if min_len < 1:
1322
+ min_len = 1
1323
+ if win_len < min_len:
1324
+ if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
1325
+ log(
1326
+ "fa2.gate_skip",
1327
+ branch="win.decode",
1328
+ reason="below_min_len",
1329
+ win_len=int(win_len),
1330
+ min_len=int(min_len),
1331
+ )
1332
+ start = end - win_len
1333
+ return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True)
1334
+ start = end - win_len
1335
+ k = K_win[:, :, start:end]
1336
+ v = V_win[:, :, start:end]
1337
+ N = B * G
1338
+ q_rows = q_t.reshape(N, h, Dk).unsqueeze(1) # [N,1,h,Dk]
1339
+ k_rows = k.reshape(N, win_len, Dk).unsqueeze(2).expand(N, win_len, h, Dk)
1340
+ v_rows = v.reshape(N, win_len, v.shape[-1]).unsqueeze(2).expand(N, win_len, h, v.shape[-1])
1341
+ try:
1342
+ o = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False) # [N,1,h,Dv]
1343
+ o = o.squeeze(1).reshape(B, G, h, -1)
1344
+ if not torch.isfinite(o).all():
1345
+ return attention_bgh(q_t, k, v, causal=True)
1346
+ return o
1347
+ except Exception as e:
1348
+ log("warn.fa2_unexpected_fallback", branch="win.decode", error=str(e)[:100])
1349
+ return attention_bgh(q_t, k, v, causal=True)
1350
+
1351
+
1352
+ def compressed_attention_fa2_decode(
1353
+ q_t: torch.Tensor, K_cmp: torch.Tensor, V_cmp: torch.Tensor, L: int
1354
+ ) -> torch.Tensor:
1355
+ if L <= 0:
1356
+ B, G, h, _ = q_t.shape
1357
+ return torch.zeros((B, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
1358
+ B, G, h, Dk = q_t.shape
1359
+ # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
1360
+ if _is_sm89(q_t.device) and not _fa2_forced():
1361
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
1362
+ log(
1363
+ "fa2.gate_skip",
1364
+ branch="cmp.decode",
1365
+ reason="sm89_guard",
1366
+ forced=bool(_fa2_forced()),
1367
+ )
1368
+ return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
1369
+ ok, why = fa2_supported_verbose(q_t.device, q_t.dtype, Dk)
1370
+ if not ok:
1371
+ if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
1372
+ log("fa2.gate_skip", branch="cmp.decode", reason=why)
1373
+ return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
1374
+ try:
1375
+ min_len = int(os.getenv("NSA_FA2_MIN_LEN_CMP", "16"))
1376
+ except Exception:
1377
+ min_len = 16
1378
+ if min_len < 1:
1379
+ min_len = 1
1380
+ if L < min_len:
1381
+ if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
1382
+ log(
1383
+ "fa2.gate_skip",
1384
+ branch="cmp.decode",
1385
+ reason="below_min_len",
1386
+ L=int(L),
1387
+ min_len=int(min_len),
1388
+ )
1389
+ return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
1390
+ k = K_cmp[:, :, :L]
1391
+ v = V_cmp[:, :, :L]
1392
+ N = B * G
1393
+ q_rows = q_t.reshape(N, h, Dk).unsqueeze(1)
1394
+ k_rows = k.reshape(N, L, Dk).unsqueeze(2).expand(N, L, h, Dk)
1395
+ v_rows = v.reshape(N, L, v.shape[-1]).unsqueeze(2).expand(N, L, h, v.shape[-1])
1396
+ try:
1397
+ o = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False)
1398
+ o = o.squeeze(1).reshape(B, G, h, -1)
1399
+ if not torch.isfinite(o).all():
1400
+ return attention_bgh(q_t, k, v, causal=True)
1401
+ return o
1402
+ except Exception:
1403
+ return attention_bgh(q_t, k, v, causal=True)
nsa/core/block_index.py ADDED
@@ -0,0 +1,100 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+
7
+
8
+ @dataclass
9
+ class BlockMeta:
10
+ l: int
11
+ d: int
12
+ l_sel: int
13
+ n_sel: int
14
+ w: int
15
+ cmp_starts: torch.Tensor # [S_cmp]
16
+ sel_starts: torch.Tensor # [S_sel]
17
+ # CSR representation: (indptr, indices, values) mapping cmp_idx -> {sel_idx: weight}
18
+ M_csl_indptr: torch.Tensor
19
+ M_csl_indices: torch.Tensor
20
+ M_csl_values: torch.Tensor
21
+ # COO representation for fast batched matmul
22
+ M_csl_coo_indices: torch.Tensor # [2, nnz] rows, cols
23
+ M_csl_coo_values: torch.Tensor # [nnz]
24
+
25
+
26
+ def build_block_starts(
27
+ seq_len: int, l: int, d: int, l_sel: int
28
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
29
+ if d <= 0 or l <= 0 or l_sel <= 0:
30
+ raise ValueError("Block parameters must be positive")
31
+ # compression blocks (overlapped)
32
+ max_cmp = 0 if seq_len < l else (seq_len - l) // d + 1
33
+ cmp_starts = torch.arange(max_cmp, dtype=torch.int32) * d
34
+ # selection blocks (non-overlapped)
35
+ max_sel = 0 if seq_len <= 0 else (seq_len + l_sel - 1) // l_sel
36
+ sel_starts = torch.arange(max_sel, dtype=torch.int32) * l_sel
37
+ return cmp_starts, sel_starts
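+ # Example: seq_len=64, l=32, d=16, l_sel=64 gives max_cmp = (64-32)//16 + 1 = 3, so
+ # cmp_starts=[0,16,32] (overlapped windows [0,32), [16,48), [32,64)), and
+ # max_sel = ceil(64/64) = 1, so sel_starts=[0].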
38
+
39
+
40
+ def _overlap_len(a0: int, a1: int, b0: int, b1: int) -> int:
41
+ return max(0, min(a1, b1) - max(a0, b0))
42
+
43
+
44
+ def build_M_csl_csr(
45
+ seq_len: int, l: int, d: int, l_sel: int
46
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
47
+ # Build CSR with fractional-overlap weights from cmp blocks to sel blocks
48
+ cmp_starts, sel_starts = build_block_starts(seq_len, l, d, l_sel)
49
+ indptr = [0]
50
+ indices: List[int] = []
51
+ values: List[float] = []
52
+ for cmp_i, s in enumerate(cmp_starts.tolist()):
53
+ a0, a1 = s, s + l
54
+ total = 0
55
+ row_pairs: List[Tuple[int, int]] = []
56
+ for sel_j, t in enumerate(sel_starts.tolist()):
57
+ b0, b1 = t, t + l_sel
58
+ ov = _overlap_len(a0, a1, b0, b1)
59
+ if ov > 0:
60
+ row_pairs.append((sel_j, ov))
61
+ total += ov
62
+ # normalize by total overlap to get fractional weights
63
+ if total > 0:
64
+ for sel_j, ov in row_pairs:
65
+ indices.append(sel_j)
66
+ values.append(ov / total)
67
+ indptr.append(len(indices))
68
+ return (
69
+ torch.tensor(indptr, dtype=torch.int32),
70
+ torch.tensor(indices, dtype=torch.int32),
71
+ torch.tensor(values, dtype=torch.float32),
72
+ )
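+ # Example of the fractional weights: with seq_len=128, l=32, d=16, l_sel=64 the cmp
+ # block starting at 48 spans [48,80) and overlaps sel blocks [0,64) and [64,128) by
+ # 16 tokens each, so its CSR row holds weights {0: 0.5, 1: 0.5}; a cmp block fully
+ # inside one sel block gets a single weight of 1.0.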
73
+
74
+
75
+ def build_block_meta(seq_len: int, l: int, d: int, l_sel: int, n_sel: int, w: int) -> BlockMeta:
76
+ if l % d != 0 or l_sel % d != 0:
77
+ # Enforce divisibility by default (per PRD); general overlaps allowed later if needed
78
+ raise ValueError("Require d|l and d|l_sel in M0")
79
+ cmp_starts, sel_starts = build_block_starts(seq_len, l, d, l_sel)
80
+ indptr, indices, values = build_M_csl_csr(seq_len, l, d, l_sel)
81
+ # Build COO from CSR
82
+ rows: List[int] = []
83
+ for r in range(len(cmp_starts)):
84
+ start, end = int(indptr[r].item()), int(indptr[r + 1].item())
85
+ rows.extend([r] * (end - start))
86
+ coo_indices = torch.stack([torch.tensor(rows, dtype=torch.int32), indices.clone()], dim=0)
87
+ return BlockMeta(
88
+ l=l,
89
+ d=d,
90
+ l_sel=l_sel,
91
+ n_sel=n_sel,
92
+ w=w,
93
+ cmp_starts=cmp_starts,
94
+ sel_starts=sel_starts,
95
+ M_csl_indptr=indptr,
96
+ M_csl_indices=indices,
97
+ M_csl_values=values,
98
+ M_csl_coo_indices=coo_indices,
99
+ M_csl_coo_values=values.clone(),
100
+ )
nsa/core/collate.py ADDED
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+
6
+
7
+ def collate_token_batch(
8
+ sequences: List[List[int]],
9
+ *,
10
+ pad_id: int = 0,
11
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
12
+ """
13
+ Collate token id sequences (var-length) into padded tensors and masks with label shift.
14
+
15
+ Args:
16
+ sequences: list of token id lists
17
+ pad_id: id used for padding
18
+ Returns:
19
+ input_ids: [B,S_max]
20
+ labels: [B,S_max] (next-token labels; last position masked out)
21
+ attn_mask: [B,S_max] (True for valid tokens)
22
+ loss_mask: [B,S_max] (True for positions to include in loss)
23
+ lengths: [B]
24
+ cu_seqlens:[B+1] cumulative lengths
25
+ """
26
+ B = len(sequences)
27
+ lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.int32)
28
+ S_max = int(lengths.max().item()) if B > 0 else 0
29
+ input_ids = torch.full((B, S_max), pad_id, dtype=torch.long)
30
+ labels = torch.full((B, S_max), pad_id, dtype=torch.long)
31
+ attn_mask = torch.zeros((B, S_max), dtype=torch.bool)
32
+ loss_mask = torch.zeros((B, S_max), dtype=torch.bool)
33
+ for b, seq in enumerate(sequences):
34
+ L = len(seq)
35
+ if L == 0:
36
+ continue
37
+ input_ids[b, :L] = torch.tensor(seq, dtype=torch.long)
38
+ attn_mask[b, :L] = True
39
+ # next-token labels (shifted left by 1), last token has no next label
40
+ labels[b, : L - 1] = input_ids[b, 1:L]
41
+ loss_mask[b, : L - 1] = True
42
+ # cu_seqlens for varlen APIs
43
+ cu = torch.zeros((B + 1,), dtype=torch.int32)
44
+ cu[1:] = torch.cumsum(lengths, dim=0)
45
+ return input_ids, labels, attn_mask, loss_mask, lengths, cu
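+ # Illustrative example (values chosen here, not from a dataset): for
+ # sequences=[[5, 6, 7], [8, 9]] and pad_id=0 the collate yields
+ #   input_ids=[[5,6,7],[8,9,0]], labels=[[6,7,0],[9,0,0]],
+ #   attn_mask=[[T,T,T],[T,T,F]], loss_mask=[[T,T,F],[T,F,F]],
+ #   lengths=[3,2], cu_seqlens=[0,3,5].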
nsa/core/compress_pool.py ADDED
@@ -0,0 +1,39 @@
1
+ from __future__ import annotations
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from .rope import apply_rope
8
+
9
+
10
+ def avg_pool_phi_rope_kv(
11
+ K_raw: torch.Tensor,
12
+ V_raw: torch.Tensor,
13
+ l: int,
14
+ d: int,
15
+ pos: Optional[torch.Tensor] = None,
16
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
17
+ # Apply RoPE to K before ϕ; use absolute positions if provided
18
+ S = K_raw.shape[2]
19
+ if pos is None:
20
+ pos = torch.arange(S, device=K_raw.device)
21
+ K_rope = apply_rope(K_raw, pos)
22
+ V_rope = V_raw
23
+ # Expect shapes [B,G,S,D*]
24
+ B, G, S, Dk = K_rope.shape
25
+ # If sequence shorter than kernel, no compressed tokens yet
26
+ if S < l:
27
+ return (
28
+ torch.zeros((B, G, 0, Dk), device=K_rope.device, dtype=K_rope.dtype),
29
+ torch.zeros((B, G, 0, V_rope.shape[-1]), device=V_rope.device, dtype=V_rope.dtype),
30
+ )
31
+ # Average-pool over time with kernel l and stride d (each compressed token summarizes one length-l block of past tokens)
32
+ Kf = K_rope.reshape(B * G, S, Dk).transpose(1, 2).unsqueeze(3) # [B*G, Dk, S, 1]
33
+ Vf = V_rope.reshape(B * G, S, -1).transpose(1, 2).unsqueeze(3)
34
+ Kp = F.avg_pool2d(Kf[:, :, :S, :], kernel_size=(l, 1), stride=(d, 1)) # [B*G, Dk, S_cmp, 1]
35
+ Vp = F.avg_pool2d(Vf[:, :, :S, :], kernel_size=(l, 1), stride=(d, 1))
36
+ S_cmp = Kp.shape[2]
37
+ K_cmp = Kp.squeeze(3).transpose(1, 2).reshape(B, G, S_cmp, Dk)
38
+ V_cmp = Vp.squeeze(3).transpose(1, 2).reshape(B, G, S_cmp, V_rope.shape[-1])
39
+ return K_cmp, V_cmp
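+ # With kernel l and stride d, avg_pool2d emits S_cmp = (S - l)//d + 1 compressed
+ # tokens once S >= l; e.g. S=128, l=32, d=16 gives S_cmp = 7, matching the cmp_starts
+ # schedule in block_index.build_block_starts.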
nsa/core/debug.py ADDED
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+ import os
3
+ from typing import Any, Dict
4
+
5
+
6
+ def _flag(name: str) -> bool:
7
+ val = os.getenv(name, "0").lower()
8
+ return val in ("1", "true", "yes")
9
+
10
+
11
+ def debug_enabled() -> bool:
12
+ return _flag("NSA_DEBUG_LOG")
13
+
14
+
15
+ _COUNTS: Dict[str, int] = {}
16
+
17
+
18
+ def log(tag: str, **fields: Any) -> None:
19
+ if not debug_enabled():
20
+ return
21
+ limit_env = os.getenv("NSA_LOG_LIMIT")
22
+ if limit_env is not None:
23
+ try:
24
+ limit = int(limit_env)
25
+ except Exception:
26
+ limit = 0
27
+ if limit > 0:
28
+ cnt = _COUNTS.get(tag, 0)
29
+ if cnt >= limit:
30
+ return
31
+ _COUNTS[tag] = cnt + 1
32
+ parts = [f"{k}={_safe(v)}" for k, v in fields.items()]
33
+ print(f"NSA-LOG {tag} " + " ".join(parts))
34
+
35
+
36
+ def _safe(v: Any) -> str:
37
+ try:
38
+ if isinstance(v, int | float | str):
39
+ return str(v)
40
+ if hasattr(v, "shape"):
41
+ return str(tuple(int(x) for x in v.shape))
42
+ return str(v)
43
+ except Exception:
44
+ return "<unrepr>"
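+ # Usage sketch: with NSA_DEBUG_LOG=1 set, log("fa2.win.bucket", L=64, N=8) prints
+ # "NSA-LOG fa2.win.bucket L=64 N=8"; NSA_LOG_LIMIT caps how many lines each tag may emit.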
nsa/core/flags.py ADDED
@@ -0,0 +1,80 @@
1
+ from __future__ import annotations
2
+ import os
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ def env_true(name: str, default: bool = False) -> bool:
9
+ v = os.getenv(name)
10
+ if v is None:
11
+ return default
12
+ v = v.strip().lower()
13
+ return v in ("1", "true", "yes", "on")
14
+
15
+
16
+ def env_int(name: str, default: int) -> int:
17
+ try:
18
+ return int(os.getenv(name, str(default)))
19
+ except Exception:
20
+ return default
21
+
22
+
23
+ def is_sm89(device: Optional[torch.device] = None) -> bool:
24
+ dev = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
25
+ if dev.type != "cuda":
26
+ return False
27
+ try:
28
+ cap = torch.cuda.get_device_capability(dev)
29
+ return cap == (8, 9)
30
+ except Exception:
31
+ return False
32
+
33
+
34
+ def torch_triton_version_pairing_ok() -> bool:
35
+ try:
36
+ import triton # noqa: F401
37
+
38
+ tv = triton.__version__
39
+ except ImportError:
40
+ tv = "<none>"
41
+ except Exception:
42
+ tv = "<unknown>"
43
+ try:
44
+ tt = torch.__version__
45
+ except Exception:
46
+ tt = "<unknown>"
47
+ # Basic heuristic: 2.2.x ↔ triton 2.2.x; 2.3.x ↔ 2.3.x; 2.4+ ↔ 3.x
48
+ try:
49
+ major_minor = ".".join((tt or "").split("+")[0].split(".")[:2])
50
+ parts = major_minor.split(".")
51
+ t_major = int(parts[0])
52
+ t_minor = int(parts[1])
53
+ if t_major != 2:
54
+ return True # do not gate non-2.x
55
+ if t_minor in (2, 3):
56
+ return tv.startswith(f"{t_minor}.")
57
+ if t_minor >= 4:
58
+ return tv.startswith("3.")
59
+ return True
60
+ except (ValueError, IndexError):
61
+ return True
62
+
63
+
64
+ def execution_routing_summary() -> dict:
65
+ """Return a snapshot of routing-related flags and runtime probes."""
66
+ info = {
67
+ "cuda": torch.cuda.is_available(),
68
+ "sm89": is_sm89(),
69
+ "torch": torch.__version__,
70
+ }
71
+ try:
72
+ import triton
73
+
74
+ info["triton"] = triton.__version__
75
+ except Exception:
76
+ info["triton"] = "<none>"
77
+ info["NSA_USE_TRITON_SEL"] = env_true("NSA_USE_TRITON_SEL", False)
78
+ info["NSA_TRITON_SEL_FORCE"] = env_true("NSA_TRITON_SEL_FORCE", False)
79
+ info["NSA_USE_FA2"] = env_true("NSA_USE_FA2", False)
80
+ return info
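+ # Illustrative output on a CPU-only machine (values indicative only):
+ # {"cuda": False, "sm89": False, "torch": "2.x", "triton": "<none>",
+ #  "NSA_USE_TRITON_SEL": False, "NSA_TRITON_SEL_FORCE": False, "NSA_USE_FA2": False}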
nsa/core/nsa_attention.py ADDED
@@ -0,0 +1,1850 @@
1
+ from __future__ import annotations
2
+ import os
3
+ from typing import List, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from nsa.cache.kv_cache import NSA_KV
10
+ from nsa.core.attention_kernels import (
11
+ compressed_attention_fa2,
12
+ compressed_attention_fa2_decode,
13
+ grouped_selection_attention,
14
+ grouped_selection_attention_masked,
15
+ grouped_selection_attention_packed,
16
+ sliding_window_attention_fa2,
17
+ sliding_window_attention_fa2_decode,
18
+ )
19
+ from nsa.core.block_index import build_block_meta
20
+ from nsa.core.compress_pool import avg_pool_phi_rope_kv
21
+ from nsa.core.debug import log
22
+ from nsa.core.rope import apply_rope
23
+ from nsa.core.selection_scorer import (
24
+ compute_pcmp_all,
25
+ map_pcmp_to_pslc_batched,
26
+ select_topn_ranges,
27
+ select_topn_ranges_batched,
28
+ verify_mapping_equivalence,
29
+ )
30
+ from nsa.kernels.flash_wrappers import attention_bgh
31
+
32
+
33
+ class GateMLP(nn.Module):
34
+ def __init__(self, d_k: int, hidden: Optional[int] = None):
35
+ super().__init__()
36
+ hidden = hidden or max(1, d_k // 2)
37
+ self.fc1 = nn.Linear(d_k, hidden)
38
+ self.fc2 = nn.Linear(hidden, 3)
39
+ # Initialize fc2 with small random values to break symmetry and enable learning
40
+ # Use Xavier uniform with reduced scale to start near uniform but allow differentiation
41
+ nn.init.xavier_uniform_(self.fc2.weight, gain=0.1)
42
+ nn.init.zeros_(self.fc2.bias) # Keep bias at zero for initial balance
43
+ # Cache environment variables at init to avoid hot path parsing
44
+ self._force_uniform_gate = os.getenv("NSA_FORCE_UNIFORM_GATE", "0").lower() in (
45
+ "1",
46
+ "true",
47
+ "yes",
48
+ )
49
+ self._force_branch = os.getenv("NSA_FORCE_BRANCH")
50
+
51
+ def forward(self, q_group_pooled: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
52
+ # Uniform gate override for debugging DDP hangs
53
+ if self._force_uniform_gate:
54
+ one_third = 1.0 / 3.0
55
+ shape = (*q_group_pooled.shape[:-1], 3)
56
+ return torch.full(
57
+ shape, one_third, device=q_group_pooled.device, dtype=q_group_pooled.dtype
58
+ )
59
+ fb = self._force_branch
60
+ if fb:
61
+ fb = fb.strip().lower()
62
+ if fb in ("cmp", "sel", "win"):
63
+ idx = 0 if fb == "cmp" else (1 if fb == "sel" else 2)
64
+ one = torch.zeros(
65
+ (*q_group_pooled.shape[:-1], 3),
66
+ device=q_group_pooled.device,
67
+ dtype=q_group_pooled.dtype,
68
+ )
69
+ one[..., idx] = 1.0
70
+ return one
71
+ x = F.silu(self.fc1(q_group_pooled))
72
+ g = self.fc2(x) / max(tau, 1e-6)
73
+ p = F.softmax(g, dim=-1)
74
+ # Hard one-hot if extremely peaked to avoid numerical drift in ablations/tests
75
+ with torch.no_grad():
76
+ top2 = torch.topk(g, k=2, dim=-1).values
77
+ peaked = (top2[..., 0] - top2[..., 1]) > 50.0
78
+ if peaked.any():
79
+ one_hot = torch.zeros_like(p)
80
+ idx = torch.argmax(g, dim=-1, keepdim=True)
81
+ one_hot.scatter_(-1, idx, 1.0)
82
+ p = torch.where(peaked.unsqueeze(-1), one_hot, p)
83
+ return p
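+ # The returned tensor has shape [..., 3]: softmax weights over (cmp, sel, win) per
+ # query group, with logits divided by tau; the hard one-hot path only triggers when
+ # the top (temperature-scaled) logit exceeds the runner-up by more than 50.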
84
+
85
+
86
+ def _fused_gate_combine_bsg(
87
+ q_gp: torch.Tensor, # [B,S,G,Dk]
88
+ O_cmp: torch.Tensor, # [B,S,G,h,Dv]
89
+ O_sel: torch.Tensor, # [B,S,G,h,Dv]
90
+ O_win: torch.Tensor, # [B,S,G,h,Dv]
91
+ fc1_w: torch.Tensor,
92
+ fc1_b: torch.Tensor | None,
93
+ fc2_w: torch.Tensor,
94
+ fc2_b: torch.Tensor | None,
95
+ tau: float,
96
+ ) -> torch.Tensor:
97
+ import torch.nn.functional as _F
98
+ x = _F.silu(_F.linear(q_gp, fc1_w, fc1_b))
99
+ g = _F.linear(x, fc2_w, fc2_b) / max(tau, 1e-6)
100
+ p = _F.softmax(g, dim=-1)
101
+ w_cmp = p[..., 0:1].unsqueeze(-1)
102
+ w_sel = p[..., 1:2].unsqueeze(-1)
103
+ w_win = p[..., 2:3].unsqueeze(-1)
104
+ return w_cmp * O_cmp + w_sel * O_sel + w_win * O_win
105
+
106
+
107
+ def _fused_gate_combine_bg(
108
+ q_gp: torch.Tensor, # [B,G,Dk]
109
+ O_cmp: torch.Tensor, # [B,G,h,Dv]
110
+ O_sel: torch.Tensor, # [B,G,h,Dv]
111
+ O_win: torch.Tensor, # [B,G,h,Dv]
112
+ fc1_w: torch.Tensor,
113
+ fc1_b: torch.Tensor | None,
114
+ fc2_w: torch.Tensor,
115
+ fc2_b: torch.Tensor | None,
116
+ tau: float,
117
+ ) -> torch.Tensor:
118
+ import torch.nn.functional as _F
119
+ x = _F.silu(_F.linear(q_gp, fc1_w, fc1_b))
120
+ g = _F.linear(x, fc2_w, fc2_b) / max(tau, 1e-6)
121
+ p = _F.softmax(g, dim=-1)
122
+ w_cmp = p[..., 0:1].unsqueeze(-1)
123
+ w_sel = p[..., 1:2].unsqueeze(-1)
124
+ w_win = p[..., 2:3].unsqueeze(-1)
125
+ return w_cmp * O_cmp + w_sel * O_sel + w_win * O_win
126
+
127
+
128
+ def _compute_gate_stats(gates: torch.Tensor) -> dict:
129
+ """Compute gate health statistics for monitoring.
130
+
131
+ Args:
132
+ gates: Gate probabilities [B, S, G, 3] or [B, G, 3]
133
+
134
+ Returns:
135
+ Dict with gate statistics: entropy, max_gate, branch_shares
136
+ """
137
+ with torch.no_grad():
138
+ # Flatten to [*, 3] for consistent computation
139
+ gates_flat = gates.view(-1, 3)
140
+
141
+ # Gate entropy (should be > 0.5 for healthy mixing)
142
+ entropy = -(gates_flat * (gates_flat + 1e-8).log()).sum(dim=-1)
143
+ mean_entropy = entropy.mean().item()
144
+ min_entropy = entropy.min().item()
145
+
146
+ # Max gate value (should be < 0.9 to avoid collapse)
147
+ max_gate = gates_flat.max(dim=-1)[0]
148
+ mean_max_gate = max_gate.mean().item()
149
+ max_max_gate = max_gate.max().item()
150
+
151
+ # Branch usage shares (should be balanced)
152
+ branch_shares = gates_flat.mean(dim=0).tolist() # [cmp, sel, win]
153
+
154
+ # Gate collapse detection (entropy < 0.1 and max_gate > 0.95)
155
+ collapsed = (entropy < 0.1) & (max_gate > 0.95)
156
+ collapse_fraction = collapsed.float().mean().item()
157
+
158
+ return {
159
+ "entropy_mean": mean_entropy,
160
+ "entropy_min": min_entropy,
161
+ "max_gate_mean": mean_max_gate,
162
+ "max_gate_max": max_max_gate,
163
+ "branch_shares": branch_shares, # [cmp, sel, win]
164
+ "collapse_fraction": collapse_fraction,
165
+ "total_gates": len(gates_flat),
166
+ }
167
+
168
+
169
+ class NSAAttention(nn.Module):
170
+ """
171
+ Native Sparse Attention (NSA) module (M0 steel-thread).
172
+
173
+ Shapes:
174
+ - Input x (prefill): [B,S,dim]; x (decode): [B,1,dim]
175
+ - Heads: n_heads, grouped into n_kv_groups with h_per_group = n_heads // n_kv_groups
176
+ - Projections produce:
177
+ - Q: [B,S,G,h,Dk]
178
+ - K/V per-branch: [B,G,S,D*]
179
+
180
+ Returns:
181
+ - out: [B,S,dim] (prefill) or [B,1,dim] (decode)
182
+ - kv: updated NSA_KV caches
183
+
184
+ Notes:
185
+ - M0 constraints: SDPA-only, fixed sequence length in tests, deterministic.
186
+ - Masked/packed fast paths are env-gated with `NSA_FORCE_PARITY` fallback.
187
+ """
188
+
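+ # Usage sketch (hedged): construction follows the signature below; a forward pass
+ # additionally needs an NSA_KV cache object, which is built elsewhere in this
+ # package and is not shown here.
+ #
+ #   attn = NSAAttention(dim=768, n_heads=12, n_kv_groups=2, d_k=64, d_v=64)
+ #   x = torch.randn(1, 128, 768)            # [B, S, dim]
+ #   out, kv = attn(x, kv, prefill=True)     # out: [1, 128, 768]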
189
+ def __init__(
190
+ self,
191
+ dim: int,
192
+ n_heads: int,
193
+ n_kv_groups: int,
194
+ d_k: int,
195
+ d_v: int,
196
+ l: int = 32,
197
+ d: int = 16,
198
+ l_sel: int = 64,
199
+ n_sel: int = 16,
200
+ w: int = 512,
201
+ phi: str = "avg",
202
+ gate_hidden: Optional[int] = None,
203
+ gate_temp: float = 1.0,
204
+ rope_impl: str = "llama",
205
+ use_flash: bool = True,
206
+ use_triton_sel: bool = False,
207
+ ) -> None:
208
+ super().__init__()
209
+ assert n_heads % n_kv_groups == 0, "heads must be divisible by kv groups"
210
+ # M0 config validation (PRD enforces divisibility)
211
+ if l % d != 0 or l_sel % d != 0:
212
+ raise ValueError("M0 requires d|l and d|l_sel; set valid block sizes/stride.")
213
+ self.dim = dim
214
+ self.n_heads = n_heads
215
+ self.n_kv_groups = n_kv_groups
216
+ self.h_per_group = n_heads // n_kv_groups
217
+ self.d_k = d_k
218
+ self.d_v = d_v
219
+ self.l = l
220
+ self.d = d
221
+ self.l_sel = l_sel
222
+ self.n_sel = n_sel
223
+ self.w = w
224
+ self.gate_temp = gate_temp
225
+ self.phi_type = (phi or "avg").lower()
226
+
227
+ # Gate health tracking for M8 monitoring
228
+ self._last_gate_stats = None
229
+ # M8: Selection length stats for monitoring (updated each forward)
230
+ self._last_sel_stats: Optional[dict] = None
231
+
232
+ # M8: Fallback counters for routing monitoring
233
+ self._fallback_counters = {
234
+ "selection_triton_fails": 0,
235
+ "selection_cuda_fails": 0,
236
+ "selection_pack_fails": 0,
237
+ "selection_mask_fails": 0,
238
+ "compressed_fa2_fails": 0,
239
+ "sliding_fa2_fails": 0,
240
+ "total_fallbacks": 0,
241
+ }
242
+
243
+ # RoPE scaling and prefill tiling for long-context demos (env-overridable)
244
+ try:
245
+ rs = float(os.getenv("NSA_ROPE_SCALE", "1.0"))
246
+ if not (rs > 0.0) or rs != rs: # require positive finite
247
+ rs = 1.0
248
+ self.rope_scale = rs
249
+ except ValueError:
250
+ self.rope_scale = 1.0
251
+ try:
252
+ pt = int(os.getenv("NSA_PREFILL_TILE", "0"))
253
+ if pt < 0:
254
+ pt = 0
255
+ self.prefill_tile = pt
256
+ except ValueError:
257
+ self.prefill_tile = 0
258
+ # Projections
259
+ self.W_Q = nn.Linear(dim, n_heads * d_k, bias=False)
260
+ self.W_K_sel = nn.Linear(dim, n_kv_groups * d_k, bias=False)
261
+ self.W_V_sel = nn.Linear(dim, n_kv_groups * d_v, bias=False)
262
+ self.W_K_win = nn.Linear(dim, n_kv_groups * d_k, bias=False)
263
+ self.W_V_win = nn.Linear(dim, n_kv_groups * d_v, bias=False)
264
+ self.W_K_cmp = nn.Linear(dim, n_kv_groups * d_k, bias=False)
265
+ self.W_V_cmp = nn.Linear(dim, n_kv_groups * d_v, bias=False)
266
+ self.out = nn.Linear(n_heads * d_v, dim, bias=False)
267
+ self.gate = GateMLP(d_k, gate_hidden)
268
+ # Default FA-2 usage (can be overridden by env flags)
269
+ self.use_flash_default = use_flash
270
+ # One-time SDPA backend audit flag
271
+ self._sdpa_audited = False
272
+ # Selection Triton toggle (M4)
273
+ self.use_triton_sel = use_triton_sel
274
+ # Cache environment variables to avoid repeated parsing in hot path
275
+ self._cache_env_vars()
276
+ # Optional learnable ϕ via depthwise Conv1d over time with kernel l and stride d
277
+ # Initialize to average pooling for parity with M0
278
+ self.phi_k_conv: Optional[nn.Conv1d]
279
+ self.phi_v_conv: Optional[nn.Conv1d]
280
+ if self.phi_type == "mlp":
281
+ self.phi_k_conv = nn.Conv1d(
282
+ self.d_k, self.d_k, kernel_size=self.l, stride=self.d, groups=self.d_k, bias=False
283
+ )
284
+ self.phi_v_conv = nn.Conv1d(
285
+ self.d_v, self.d_v, kernel_size=self.l, stride=self.d, groups=self.d_v, bias=False
286
+ )
287
+ with torch.no_grad():
288
+ self.phi_k_conv.weight.fill_(1.0 / float(self.l))
289
+ self.phi_v_conv.weight.fill_(1.0 / float(self.l))
290
+ else:
291
+ self.phi_k_conv = None
292
+ self.phi_v_conv = None
293
+
294
+ def _cache_env_vars(self) -> None:
295
+ """Cache environment variables to avoid repeated parsing in hot path."""
296
+
297
+ def parse_bool(name: str, default: str = "0") -> bool:
298
+ # True when the named env var is "1"/"true"/"yes" (case-insensitive).
+ return os.getenv(name, default).lower() in ("1", "true", "yes")
299
+
300
+ # Cache frequently accessed environment variables
301
+ # Raw parsed flags
302
+ self._env_cache = {
303
+ "static": parse_bool("NSA_ENV_STATIC", "0"),
304
+ "force_uniform_gate": parse_bool("NSA_FORCE_UNIFORM_GATE", "0"),
305
+ "force_branch": os.getenv("NSA_FORCE_BRANCH"),
306
+ "prefill_batched": parse_bool("NSA_PREFILL_BATCHED", "0"),
307
+ "strict_asserts": parse_bool("NSA_STRICT_ASSERTS", "0"),
308
+ "force_parity": parse_bool("NSA_FORCE_PARITY", "0"),
309
+ "use_sel_pack": parse_bool("NSA_USE_SEL_PACK", "1"),
310
+ "use_triton_sel": parse_bool("NSA_USE_TRITON_SEL", "0") or self.use_triton_sel,
311
+ "use_cuda_sel": parse_bool("NSA_SEL_CUDA", "0"),
312
+ "use_sel_varlen": parse_bool("NSA_USE_SEL_VARLEN", "0"),
313
+ # Hard override to force masked selection path (debug/triage)
314
+ "force_sel_mask": parse_bool("NSA_FORCE_SEL_MASK", "0"),
315
+ "fa2_all": parse_bool("NSA_USE_FA2", "0"),
316
+ "fa2_win": parse_bool("NSA_USE_FA2_WIN", "0"),
317
+ "fa2_cmp": parse_bool("NSA_USE_FA2_CMP", "0"),
318
+ "use_sel_mask": parse_bool("NSA_USE_SEL_MASK", "0"),
319
+ "use_cmp_mask": parse_bool("NSA_USE_CMP_MASK", "1"),
320
+ "use_win_mask": parse_bool("NSA_USE_WIN_MASK", "1"),
321
+ "verify_eq9": parse_bool("NSA_VERIFY_EQ9_MAPPING", "0"),
322
+ "stopgrad_gates": parse_bool("NSA_STOPGRAD_GATES", "0"),
323
+ "nvtx": parse_bool("NSA_NVTX", "0"),
324
+ "debug_compare": parse_bool("NSA_DEBUG_COMPARE", "0"),
325
+ "gate_compile": parse_bool("NSA_GATE_COMPILE", "0"),
326
+ }
327
+
328
+ # Detect whether env overrides were explicitly provided so we can honor hard-disable
329
+ fa2_all_set = "NSA_USE_FA2" in os.environ
330
+ fa2_win_set = "NSA_USE_FA2_WIN" in os.environ
331
+ fa2_cmp_set = "NSA_USE_FA2_CMP" in os.environ
332
+ self._env_cache.update(
333
+ {
334
+ "fa2_all_set": fa2_all_set,
335
+ "fa2_win_set": fa2_win_set,
336
+ "fa2_cmp_set": fa2_cmp_set,
337
+ }
338
+ )
339
+
340
+ # Compute effective FA-2 gating with sensible defaults and hard-disable semantics
341
+ fa2_all_env = self._env_cache["fa2_all"]
342
+ fa2_win_env = self._env_cache["fa2_win"]
343
+ fa2_cmp_env = self._env_cache["fa2_cmp"]
344
+
345
+ # Defaults when no explicit env flags are provided:
346
+ # - Enable compressed FA‑2 by default (robustly capability-gated at call sites)
347
+ # - Keep sliding FA‑2 off by default due to API semantics
348
+ # - Do not use the global "all" default to avoid inadvertently enabling sliding
349
+ if not (fa2_all_set or fa2_win_set or fa2_cmp_set):
350
+ fa2_all_eff = False
351
+ fa2_win_eff = False
352
+ fa2_cmp_eff = True
353
+ else:
354
+ # If NSA_USE_FA2 not set, fall back to model default; else honor explicit value
355
+ fa2_all_eff = self.use_flash_default if not fa2_all_set else fa2_all_env
356
+
357
+ # If global is explicitly set to 0, that hard-disables branch flags too
358
+ if fa2_all_set and not fa2_all_env:
359
+ fa2_win_eff = False
360
+ fa2_cmp_eff = False
361
+ else:
362
+ # Branch-specific flags only take effect if explicitly set; otherwise default off
363
+ fa2_win_eff = fa2_win_env if fa2_win_set else False
364
+ fa2_cmp_eff = fa2_cmp_env if fa2_cmp_set else False
365
+
366
+ self._env_cache.update(
367
+ {
368
+ "fa2_all_eff": fa2_all_eff,
369
+ "fa2_win_eff": fa2_win_eff,
370
+ "fa2_cmp_eff": fa2_cmp_eff,
371
+ }
372
+ )
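+ # Resulting routing for a few example settings (illustration):
+ #   no FA-2 env vars set    -> fa2_cmp_eff=True,  fa2_win_eff=False, fa2_all_eff=False
+ #   NSA_USE_FA2=0           -> all three False (explicit global hard-disable)
+ #   NSA_USE_FA2_WIN=1 only  -> fa2_win_eff=True,  fa2_cmp_eff=False, fa2_all_eff=use_flash default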
373
+ # Parse numeric values
374
+ try:
375
+ self._rope_scale = float(os.getenv("NSA_ROPE_SCALE", "1.0"))
376
+ if not (self._rope_scale > 0.0) or self._rope_scale != self._rope_scale:
377
+ self._rope_scale = 1.0
378
+ except (ValueError, TypeError):
379
+ self._rope_scale = 1.0
380
+
381
+ try:
382
+ self._prefill_tile = int(os.getenv("NSA_PREFILL_TILE", "0"))
383
+ if self._prefill_tile < 0:
384
+ self._prefill_tile = 0
385
+ except (ValueError, TypeError):
386
+ self._prefill_tile = 0
387
+ # Fused gate combine (lazy-compiled)
388
+ self._gate_fused_bsg = None
389
+ self._gate_fused_bg = None
390
+
391
+ def _shape_q(self, Q: torch.Tensor, B: int, S: int) -> torch.Tensor:
392
+ Q = Q.view(B, S, self.n_heads, self.d_k)
393
+ # group-major: [B,S,G,h,Dk]
394
+ G = self.n_kv_groups
395
+ h = self.h_per_group
396
+ return Q.view(B, S, G, h, self.d_k)
397
+
398
+ def _shape_kv(self, X: torch.Tensor, B: int, S: int) -> torch.Tensor:
399
+ G = self.n_kv_groups
400
+ return X.view(B, S, G, -1).permute(0, 2, 1, 3).contiguous() # [B,G,S,D*]
401
+
402
+ def get_gate_stats(self) -> Optional[dict]:
403
+ """Get the most recent gate statistics for monitoring.
404
+
405
+ Returns:
406
+ Dict with gate health metrics or None if no recent computation
407
+ """
408
+ return self._last_gate_stats
409
+
410
+ def get_fallback_counters(self) -> dict:
411
+ """Get fallback counters for routing monitoring.
412
+
413
+ Returns:
414
+ Dict with fallback counts per implementation type
415
+ """
416
+ return self._fallback_counters.copy()
417
+
418
+ def get_selection_stats(self) -> Optional[dict]:
419
+ """Return last computed selection length statistics, if available.
420
+
421
+ Keys:
422
+ - k_mean: mean selected K per row (float)
423
+ - k_max: max selected K in batch (int)
424
+ - rows: number of (B,S,G) rows aggregated (int)
425
+ - pct_at_max: fraction of rows equal to k_max (float)
426
+ - l_sel: configured selection block size (int)
427
+ - n_sel: configured top-n selection blocks (int)
428
+ """
429
+ return self._last_sel_stats
430
+
431
+ def reset_fallback_counters(self) -> dict:
432
+ """Reset fallback counters and return the previous values.
433
+
434
+ Returns:
435
+ Dict with fallback counts before reset
436
+ """
437
+ prev_counters = self._fallback_counters.copy()
438
+ for key in self._fallback_counters:
439
+ self._fallback_counters[key] = 0
440
+ return prev_counters
441
+
442
+ def _update_gate_stats(self, gates: torch.Tensor) -> None:
443
+ """Update stored gate statistics for monitoring."""
444
+ try:
445
+ self._last_gate_stats = _compute_gate_stats(gates)
446
+ except Exception as e:
447
+ log("warn.gate_stats_fail", error=str(e))
448
+ self._last_gate_stats = None
449
+
450
+ def _update_sel_stats_from_ranges(self, ranges: torch.Tensor) -> None:
451
+ """Compute and store selection statistics from [B,*,G,n,2] ranges tensor."""
452
+ try:
453
+ if ranges is None or ranges.numel() == 0:
454
+ self._last_sel_stats = {
455
+ "k_mean": 0.0,
456
+ "k_max": 0,
457
+ "rows": 0,
458
+ "pct_at_max": 0.0,
459
+ "l_sel": int(self.l_sel),
460
+ "n_sel": int(self.n_sel),
461
+ }
462
+ return
463
+ # ranges: [B, T, G, n, 2] or [B, G, n, 2]
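+ # e.g. one row with ranges [[0, 64], [128, 192]] contributes L = 64 + 64 = 128
+ # selected tokens; k_mean/k_max/pct_at_max summarize these per-row totals.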
464
+ if ranges.dim() == 5:
465
+ B, T, G, n, _ = ranges.shape
466
+ rs = ranges
467
+ rows = B * T * G
468
+ # [B,T,G,n]
469
+ lengths = (rs[..., 1] - rs[..., 0]).clamp_min(0)
470
+ # Sum across n ranges → [B,T,G]
471
+ L = lengths.sum(dim=-1).to(torch.int64)
472
+ elif ranges.dim() == 4:
473
+ B, G, n, _ = ranges.shape
474
+ rs = ranges
475
+ rows = B * G
476
+ lengths = (rs[..., 1] - rs[..., 0]).clamp_min(0)
477
+ L = lengths.sum(dim=-1).to(torch.int64) # [B,G]
478
+ else:
479
+ # Unknown shape; skip
480
+ return
481
+ if L.numel() == 0:
482
+ k_mean = 0.0
483
+ k_max = 0
484
+ pct_at_max = 0.0
485
+ else:
486
+ k_max = int(L.max().item())
487
+ k_mean = float(L.to(torch.float32).mean().item())
488
+ if k_max > 0:
489
+ pct_at_max = float((L == k_max).to(torch.float32).mean().item())
490
+ else:
491
+ pct_at_max = 0.0
492
+ self._last_sel_stats = {
493
+ "k_mean": k_mean,
494
+ "k_max": k_max,
495
+ "rows": int(rows),
496
+ "pct_at_max": pct_at_max,
497
+ "l_sel": int(self.l_sel),
498
+ "n_sel": int(self.n_sel),
499
+ }
500
+ except Exception as e:
501
+ log("warn.sel_stats_fail", error=str(e))
502
+ self._last_sel_stats = None
503
+
504
+ def forward(self, x: torch.Tensor, kv: NSA_KV, *, prefill: bool) -> tuple[torch.Tensor, NSA_KV]:
505
+ """
506
+ Forward pass.
507
+
508
+ Args:
509
+ x: [B,S,dim] if prefill else [B,1,dim]
510
+ kv: NSA_KV caches (updated in-place per branch)
511
+ prefill: True for batched prefill, False for single-token decode
512
+
513
+ Returns:
514
+ (out, kv): out is [B,S,dim] (prefill) or [B,1,dim] (decode)
515
+ """
516
+ # x: [B,S,dim] (prefill) or [B,1,dim] (decode)
517
+ B, S, _ = x.shape
518
+ assert x.dim() == 3, "x must be [B,S,dim]"
519
+ assert self.n_heads % self.n_kv_groups == 0, "n_heads must be divisible by n_kv_groups"
520
+ # Strict assertions may introduce GPU syncs; gate via env for tests/smokes
521
+ strict_asserts = self._env_cache.get("strict_asserts", False)
522
+
523
+ # M8: Assert causal masking - enforce mode constraints
524
+ if prefill:
525
+ assert S > 0, f"Prefill mode requires S > 0, got S={S}"
526
+ else:
527
+ assert S == 1, (
528
+ f"Decode mode requires S=1 (single token), got S={S}. "
529
+ f"This ensures proper causal ordering in decode steps."
530
+ )
531
+ if prefill:
532
+ # Optional: route prefill via single-token decode steps to support very long contexts safely.
533
+ if getattr(self, "prefill_tile", 0) and self.prefill_tile > 0:
534
+ return self._forward_prefill_via_decode(x, kv)
535
+ use_batched = self._env_cache.get("prefill_batched", False)
536
+ if use_batched:
537
+ return self._forward_prefill_batched(x, kv)
538
+ else:
539
+ return self._forward_prefill_sequential(x, kv)
540
+ else:
541
+ # Projections
542
+ # Compute absolute position offset from existing cache length for RoPE on Q
543
+ t_prev = kv.K_sel.shape[2] if hasattr(kv, "K_sel") else 0
544
+ Q_lin = self._shape_q(self.W_Q(x), B, S) # [B,S,G,h,Dk]
545
+ # Apply RoPE to Q with absolute positions (decode)
546
+ pos = torch.arange(t_prev, t_prev + S, device=x.device)
547
+ Q = apply_rope(
548
+ Q_lin.view(B, S, self.n_heads, self.d_k).reshape(B, S, self.n_heads * self.d_k),
549
+ pos,
550
+ scale=getattr(self, "rope_scale", 1.0),
551
+ )
552
+ Q = Q.view(B, S, self.n_heads, self.d_k)
553
+ G = self.n_kv_groups
554
+ h = self.h_per_group
555
+ Q = Q.view(B, S, G, h, self.d_k)
556
+ K_sel = self._shape_kv(self.W_K_sel(x), B, S)
557
+ V_sel = self._shape_kv(self.W_V_sel(x), B, S)
558
+ K_win = self._shape_kv(self.W_K_win(x), B, S)
559
+ V_win = self._shape_kv(self.W_V_win(x), B, S)
560
+ K_cmp_raw = self._shape_kv(self.W_K_cmp(x), B, S)
561
+ V_cmp_raw = self._shape_kv(self.W_V_cmp(x), B, S)
562
+
563
+ # Apply RoPE to K for selection/sliding branches using absolute position of the new token(s)
564
+ # Determine current token index before appending to caches
565
+ t_prev = kv.K_sel.shape[2] if hasattr(kv, "K_sel") else 0
566
+ pos_k = torch.arange(t_prev, t_prev + S, device=x.device)
567
+ K_sel = apply_rope(K_sel, pos_k, scale=getattr(self, "rope_scale", 1.0))
568
+ K_win = apply_rope(K_win, pos_k, scale=getattr(self, "rope_scale", 1.0))
569
+
570
+ # decode step: append raw tokens and window, emit compressed every d after warmup l
571
+ kv.update_selection_raw(K_sel, V_sel)
572
+ kv.update_window(K_win, V_win, self.w)
573
+ if not hasattr(kv, "K_cmp_raw_seq"):
574
+ kv.K_cmp_raw_seq = K_cmp_raw[:, :, :0]
575
+ kv.V_cmp_raw_seq = V_cmp_raw[:, :, :0]
576
+ kv.reads_pred = torch.zeros((0,), dtype=torch.int64, device=x.device)
577
+ kv.reads_act_total = torch.zeros((0,), dtype=torch.int64, device=x.device)
578
+ kv.reads_act_sel = torch.zeros((0,), dtype=torch.int64, device=x.device)
579
+ kv.reads_act_cmp = torch.zeros((0,), dtype=torch.int64, device=x.device)
580
+ kv.reads_act_win = torch.zeros((0,), dtype=torch.int64, device=x.device)
581
+ kv.append_cmp_raw(K_cmp_raw, V_cmp_raw)
582
+ S_raw = kv.K_cmp_raw_seq.shape[2]
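+ # Emission schedule example: with l=32, d=16 a new compressed token is emitted
+ # when S_raw reaches 32, 48, 64, ... i.e. every d raw tokens after the first l.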
583
+ if S_raw >= self.l and (S_raw - self.l) % self.d == 0:
584
+ # Emit compressed token from the last l raw tokens
585
+ K_last = kv.K_cmp_raw_seq[:, :, S_raw - self.l : S_raw, :]
586
+ V_last = kv.V_cmp_raw_seq[:, :, S_raw - self.l : S_raw, :]
587
+ pos_last = torch.arange(S_raw - self.l, S_raw, device=x.device)
588
+ if self.phi_type == "mlp":
589
+ K_cmp_new, V_cmp_new = self._phi_apply_last(K_last, V_last, pos_last)
590
+ else:
591
+ K_cmp_new, V_cmp_new = avg_pool_phi_rope_kv(
592
+ K_last, V_last, self.l, self.d, pos=pos_last
593
+ )
594
+ kv.update_compressed(
595
+ torch.cat([kv.K_cmp, K_cmp_new], dim=2) if kv.K_cmp.numel() else K_cmp_new,
596
+ torch.cat([kv.V_cmp, V_cmp_new], dim=2) if kv.V_cmp.numel() else V_cmp_new,
597
+ self.l,
598
+ self.d,
599
+ )
600
+
601
+ # Ensure block metadata exists and covers current token index for selection (expand if needed)
602
+ t_token = kv.K_sel.shape[2] - 1
603
+ if not hasattr(kv, "meta") or kv.meta.sel_starts.numel() == 0:
604
+ kv.meta = build_block_meta(
605
+ seq_len=max(t_token + 1, self.l_sel),
606
+ l=self.l,
607
+ d=self.d,
608
+ l_sel=self.l_sel,
609
+ n_sel=self.n_sel,
610
+ w=self.w,
611
+ )
612
+ else:
613
+ # If current t exceeds covered selection range, rebuild meta with expanded seq_len
614
+ sel_max_end = (
615
+ int(kv.meta.sel_starts[-1].item()) + kv.meta.l_sel
616
+ if kv.meta.sel_starts.numel() > 0
617
+ else 0
618
+ )
619
+ if (t_token + 1) > sel_max_end:
620
+ kv.meta = build_block_meta(
621
+ seq_len=t_token + 1,
622
+ l=self.l,
623
+ d=self.d,
624
+ l_sel=self.l_sel,
625
+ n_sel=self.n_sel,
626
+ w=self.w,
627
+ )
628
+ # Append predicted reads per formula for this step
629
+ num_cmp = 0 if S_raw < self.l else (S_raw - self.l) // self.d + 1
630
+ reads = num_cmp + self.n_sel * self.l_sel + min(self.w, S_raw)
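+ # Worked example with the defaults (l=32, d=16, l_sel=64, n_sel=16, w=512) at
+ # S_raw=1024: num_cmp = (1024-32)//16 + 1 = 63, sel = 16*64 = 1024,
+ # win = min(512, 1024) = 512, so reads = 63 + 1024 + 512 = 1599 per decode step.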
631
+ kv.append_reads_pred(reads)
632
+ # Append actual reads equal to formula in M0
633
+ kv.append_reads_actual(reads, self.n_sel * self.l_sel, num_cmp, min(self.w, S_raw))
634
+ log(
635
+ "decode.reads",
636
+ S_raw=int(S_raw),
637
+ num_cmp=int(num_cmp),
638
+ sel=int(self.n_sel * self.l_sel),
639
+ win=int(min(self.w, S_raw)),
640
+ total=int(reads),
641
+ )
642
+
643
+ scale = 1.0 / (self.d_k**0.5)
644
+ # Compute p_cmp only for this step (S is 1 in decode)
645
+ K_cmp_full = kv.K_cmp
646
+ p_cmp_all = compute_pcmp_all(Q, K_cmp_full, scale)
647
+ # Per-token outputs (S should be 1 in decode)
648
+ outs = []
649
+ # Use cached environment variables
650
+ env = self._env_cache
651
+
652
+ for t in range(S):
653
+ p_slc_all = map_pcmp_to_pslc_batched(p_cmp_all[:, t : t + 1], kv.meta)
654
+
655
+ # M8: Optional Eq.9 verification in decode
656
+ if self._env_cache.get("verify_eq9", False):
657
+ is_equiv, details = verify_mapping_equivalence(p_cmp_all[:, t : t + 1], kv.meta)
658
+ if not is_equiv:
659
+ log(
660
+ "error.eq9_verification_failed_decode",
661
+ msg="Eq.9 mapping verification failed in decode",
662
+ step=t,
663
+ **details,
664
+ )
665
+ p_grp = p_slc_all.sum(dim=3).squeeze(1) # [B,G,S_sel]
666
+ current_pos = kv.K_sel.shape[2] - 1 # Current token position (0-indexed)
667
+ sel_ranges = select_topn_ranges(p_grp, kv.meta, self.n_sel, current_pos, True, 2)
668
+
669
+ # M8: Assert causal masking - selection ranges cannot include future tokens
670
+ if strict_asserts and sel_ranges.numel() > 0:
671
+ # Only sync for strict asserts (debug mode)
672
+ max_end = sel_ranges[..., 1].max().item() # GPU sync only in debug
673
+ assert max_end <= current_pos + 1, (
674
+ f"Selection range violates causality: max_end={max_end} > current_pos+1={current_pos + 1}. "
675
+ f"Selection must not access future tokens."
676
+ )
677
+ # Update selection stats and observability: distance summary per step
678
+ try:
679
+ # Update per-step selection stats (decode has S==1)
680
+ self._update_sel_stats_from_ranges(sel_ranges)
681
+ starts = sel_ranges[..., 0].to(torch.int64)
682
+ ends = sel_ranges[..., 1].to(torch.int64)
683
+ lengths = (ends - starts).clamp_min(0)
684
+ dist = (kv.K_sel.shape[2] - 1) - starts
685
+ log(
686
+ "decode.select",
687
+ n_ranges=int(sel_ranges.shape[2]),
688
+ mean_len=float(lengths.float().mean().item()) if lengths.numel() else 0.0,
689
+ max_len=int(lengths.max().item()) if lengths.numel() else 0,
690
+ mean_dist=float(dist.float().mean().item()) if dist.numel() else 0.0,
691
+ max_dist=int(dist.max().item()) if dist.numel() else 0,
692
+ )
693
+ except Exception as e:
694
+ log("warn.decode.select_log_fail", error=str(e))
695
+ Q_t = Q[:, t]
696
+ K_sel_t = kv.K_sel
697
+ V_sel_t = kv.V_sel
698
+ # Selection attention routing: forced masked > Triton > CUDA > packed > masked > gather; Triton/CUDA fall back to packed SDPA on failure, the others to gather SDPA
699
+ force_parity = env["force_parity"]
700
+ use_sel_pack = env["use_sel_pack"] and not force_parity
701
+ use_triton_sel = env["use_triton_sel"] and not force_parity
702
+ use_cuda_sel = env["use_cuda_sel"] and not force_parity
703
+ force_sel_mask = env.get("force_sel_mask", False) and not force_parity
704
+ if force_sel_mask:
705
+ try:
706
+ O_sel_bt = grouped_selection_attention_masked(
707
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
708
+ )
709
+ O_sel = O_sel_bt[:, 0]
710
+ log("decode.sel.path", path="masked_forced")
711
+ except Exception as e:
712
+ self._fallback_counters["selection_mask_fails"] += 1
713
+ self._fallback_counters["total_fallbacks"] += 1
714
+ log("warn.masked_selection_forced_fallback",
715
+ error=str(e),
716
+ step=t,
717
+ Q_shape=list(Q_t.shape),
718
+ K_shape=list(K_sel_t.shape),
719
+ V_shape=list(V_sel_t.shape),
720
+ ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
721
+ total_fails=self._fallback_counters["selection_mask_fails"])
722
+ O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
723
+ elif use_triton_sel:
724
+ try:
725
+ from nsa.kernels.triton_sel_kernel import selection_attention_triton
726
+
727
+ O_sel_bt = selection_attention_triton(
728
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
729
+ )
730
+ O_sel = O_sel_bt[:, 0]
731
+ log("decode.sel.path", path="triton")
732
+ except Exception as e:
733
+ # M8: Fallback counter - Triton selection failed
734
+ self._fallback_counters["selection_triton_fails"] += 1
735
+ self._fallback_counters["total_fallbacks"] += 1
736
+ log(
737
+ "warn.triton_selection_fallback",
738
+ error=str(e),
739
+ step=t,
740
+ Q_shape=list(Q_t.shape),
741
+ K_shape=list(K_sel_t.shape),
742
+ V_shape=list(V_sel_t.shape),
743
+ ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
744
+ total_fails=self._fallback_counters["selection_triton_fails"],
745
+ )
746
+ # Fallback to packed SDPA
747
+ O_sel_bt = grouped_selection_attention_packed(
748
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
749
+ )
750
+ O_sel = O_sel_bt[:, 0]
751
+ elif use_cuda_sel:
752
+ try:
753
+ from nsa.kernels.cuda_sel_kernel import selection_attention_cuda
754
+
755
+ O_sel_bt = selection_attention_cuda(
756
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
757
+ )
758
+ O_sel = O_sel_bt[:, 0]
759
+ except Exception as e:
760
+ # M8: Fallback counter - CUDA selection failed
761
+ self._fallback_counters["selection_cuda_fails"] += 1
762
+ self._fallback_counters["total_fallbacks"] += 1
763
+ log(
764
+ "warn.cuda_selection_fallback",
765
+ error=str(e),
766
+ step=t,
767
+ Q_shape=list(Q_t.shape),
768
+ K_shape=list(K_sel_t.shape),
769
+ V_shape=list(V_sel_t.shape),
770
+ ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
771
+ total_fails=self._fallback_counters["selection_cuda_fails"],
772
+ )
773
+ # Fallback to packed SDPA
774
+ O_sel_bt = grouped_selection_attention_packed(
775
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
776
+ )
777
+ O_sel = O_sel_bt[:, 0]
778
+ elif use_sel_pack:
779
+ try:
780
+ O_sel_bt = grouped_selection_attention_packed(
781
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
782
+ )
783
+ O_sel = O_sel_bt[:, 0]
784
+ log("decode.sel.path", path="packed")
785
+ except Exception as e:
786
+ # M8: Fallback counter - Packed selection failed
787
+ self._fallback_counters["selection_pack_fails"] += 1
788
+ self._fallback_counters["total_fallbacks"] += 1
789
+ log(
790
+ "warn.packed_selection_fallback",
791
+ error=str(e),
792
+ step=t,
793
+ Q_shape=list(Q_t.shape),
794
+ K_shape=list(K_sel_t.shape),
795
+ V_shape=list(V_sel_t.shape),
796
+ ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
797
+ total_fails=self._fallback_counters["selection_pack_fails"],
798
+ )
799
+ # Fallback to gather SDPA
800
+ O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
801
+ elif self._env_cache.get("use_sel_mask", False) and not force_parity:
802
+ try:
803
+ O_sel_bt = grouped_selection_attention_masked(
804
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
805
+ )
806
+ O_sel = O_sel_bt[:, 0]
807
+ log("decode.sel.path", path="masked")
808
+ except Exception as e:
809
+ # M8: Fallback counter - Masked selection failed
810
+ self._fallback_counters["selection_mask_fails"] += 1
811
+ self._fallback_counters["total_fallbacks"] += 1
812
+ log(
813
+ "warn.masked_selection_fallback",
814
+ error=str(e),
815
+ step=t,
816
+ Q_shape=list(Q_t.shape),
817
+ K_shape=list(K_sel_t.shape),
818
+ V_shape=list(V_sel_t.shape),
819
+ ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
820
+ total_fails=self._fallback_counters["selection_mask_fails"],
821
+ )
822
+ # Fallback to gather SDPA
823
+ O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
824
+ else:
825
+ O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
826
+ win_len = min(self.w, kv.K_win.shape[2])
827
+
828
+ # M8: Assert causal masking - sliding window bounds in decode
829
+ total_tokens = kv.K_win.shape[2]
830
+ start_idx = total_tokens - win_len
831
+ end_idx = total_tokens
832
+ assert start_idx >= 0, (
833
+ f"Sliding window start index negative: start_idx={start_idx}, "
834
+ f"total_tokens={total_tokens}, win_len={win_len}"
835
+ )
836
+ assert end_idx <= total_tokens, (
837
+ f"Sliding window end exceeds cache: end_idx={end_idx} > total_tokens={total_tokens}"
838
+ )
839
+ assert win_len <= self.w, (
840
+ f"Window length exceeds max: win_len={win_len} > self.w={self.w}"
841
+ )
842
+
843
+ K_w = kv.K_win[:, :, start_idx:end_idx, :]
844
+ V_w = kv.V_win[:, :, start_idx:end_idx, :]
845
+ use_flash = (
846
+ env["fa2_all_eff"] or env["fa2_win_eff"] or env["fa2_cmp_eff"]
847
+ ) and not force_parity
848
+ if use_flash and (env["fa2_all_eff"] or env["fa2_win_eff"]):
849
+ try:
850
+ O_win = sliding_window_attention_fa2_decode(Q_t, kv.K_win, kv.V_win, self.w)
851
+ except Exception as e:
852
+ # M8: Fallback counter - Sliding FA2 failed
853
+ self._fallback_counters["sliding_fa2_fails"] += 1
854
+ self._fallback_counters["total_fallbacks"] += 1
855
+ log(
856
+ "warn.sliding_fa2_fallback",
857
+ error=str(e),
858
+ total_fails=self._fallback_counters["sliding_fa2_fails"],
859
+ )
860
+ # Fallback to standard attention
861
+ O_win = attention_bgh(
862
+ Q_t.contiguous(), K_w.contiguous(), V_w.contiguous(), causal=True
863
+ )
864
+ else:
865
+ O_win = attention_bgh(
866
+ Q_t.contiguous(), K_w.contiguous(), V_w.contiguous(), causal=True
867
+ )
868
+ S_cmp_t = kv.K_cmp.shape[2]
869
+
870
+ # M8: Assert causal masking - compressed bounds in decode
871
+ assert S_cmp_t >= 0, f"Compressed cache size negative: S_cmp_t={S_cmp_t}"
872
+ assert S_cmp_t <= kv.K_cmp.shape[2], (
873
+ f"Compressed range exceeds cache: S_cmp_t={S_cmp_t} > cache_size={kv.K_cmp.shape[2]}"
874
+ )
875
+
876
+ if use_flash and (env["fa2_all_eff"] or env["fa2_cmp_eff"]):
877
+ try:
878
+ O_cmp = compressed_attention_fa2_decode(Q_t, kv.K_cmp, kv.V_cmp, S_cmp_t)
879
+ except Exception as e:
880
+ # M8: Fallback counter - Compressed FA2 failed
881
+ self._fallback_counters["compressed_fa2_fails"] += 1
882
+ self._fallback_counters["total_fallbacks"] += 1
883
+ log(
884
+ "warn.compressed_fa2_fallback",
885
+ error=str(e),
886
+ total_fails=self._fallback_counters["compressed_fa2_fails"],
887
+ )
888
+ # Fallback to standard attention
889
+ O_cmp = attention_bgh(
890
+ Q_t.contiguous(),
891
+ kv.K_cmp[:, :, :S_cmp_t, :].contiguous(),
892
+ kv.V_cmp[:, :, :S_cmp_t, :].contiguous(),
893
+ causal=True,
894
+ )
895
+ else:
896
+ O_cmp = attention_bgh(
897
+ Q_t.contiguous(),
898
+ kv.K_cmp[:, :, :S_cmp_t, :].contiguous(),
899
+ kv.V_cmp[:, :, :S_cmp_t, :].contiguous(),
900
+ causal=True,
901
+ )
902
+ # Preserve dtype for gate input
903
+ q_gp = Q_t.mean(dim=2, dtype=Q_t.dtype)
904
+ if self._env_cache.get("gate_compile", False):
905
+ try:
906
+ fused = self._gate_fused_bg
907
+ if fused is None:
908
+ fused = _fused_gate_combine_bg
909
+ try:
910
+ fused = torch.compile(fused, mode="reduce-overhead") # type: ignore[attr-defined]
911
+ except Exception:
912
+ pass
913
+ self._gate_fused_bg = fused
914
+ O = fused(
915
+ q_gp,
916
+ O_cmp,
917
+ O_sel,
918
+ O_win,
919
+ self.gate.fc1.weight,
920
+ self.gate.fc1.bias,
921
+ self.gate.fc2.weight,
922
+ self.gate.fc2.bias,
923
+ float(self.gate_temp),
924
+ )
925
+ except Exception:
926
+ gates = self.gate(q_gp, tau=self.gate_temp)
927
+ if self._env_cache.get("stopgrad_gates", False):
928
+ gates = gates.detach()
929
+ self._update_gate_stats(gates)
930
+ try:
931
+ log(
932
+ "decode.gates",
933
+ mean=gates.mean(dim=(-1, -2)).tolist()
934
+ if gates.dim() >= 2
935
+ else gates.mean().item(),
936
+ std=gates.std(dim=(-1, -2)).tolist()
937
+ if gates.dim() >= 2
938
+ else gates.std().item(),
939
+ )
940
+ except Exception as e:
941
+ log("warn.decode.gate_log_fail", error=str(e))
942
+ w_cmp = gates[..., 0:1].unsqueeze(-1)
943
+ w_sel = gates[..., 1:2].unsqueeze(-1)
944
+ w_win = gates[..., 2:3].unsqueeze(-1)
945
+ O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win
946
+ else:
947
+ gates = self.gate(q_gp, tau=self.gate_temp)
948
+ if self._env_cache.get("stopgrad_gates", False):
949
+ gates = gates.detach()
950
+ self._update_gate_stats(gates)
951
+ try:
952
+ log(
953
+ "decode.gates",
954
+ mean=gates.mean(dim=(-1, -2)).tolist()
955
+ if gates.dim() >= 2
956
+ else gates.mean().item(),
957
+ std=gates.std(dim=(-1, -2)).tolist()
958
+ if gates.dim() >= 2
959
+ else gates.std().item(),
960
+ )
961
+ except Exception as e:
962
+ log("warn.decode.gate_log_fail", error=str(e))
963
+ w_cmp = gates[..., 0:1].unsqueeze(-1)
964
+ w_sel = gates[..., 1:2].unsqueeze(-1)
965
+ w_win = gates[..., 2:3].unsqueeze(-1)
966
+ O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win
967
+ O_heads = O.reshape(B, self.n_heads, self.d_v)
968
+ out_t = self.out(O_heads.reshape(B, 1, -1))
969
+ outs.append(out_t)
970
+ out = torch.cat(outs, dim=1)
971
+ return out, kv
972
+
973
+ def _forward_prefill_batched(self, x: torch.Tensor, kv: NSA_KV) -> tuple[torch.Tensor, NSA_KV]:
974
+ """
975
+ Vectorized prefill path.
976
+
977
+ Steps:
978
+ - Projections with RoPE(Q); RoPE applied to K before ϕ for compressed branch
979
+ - Cache updates for selection/window/compressed
980
+ - Batched p_cmp → p_slc → p_grp; top‑n ranges for all t
981
+ - Branch attentions (masked/packed per env flags), gating, projection
982
+ """
983
+ B, S, _ = x.shape
984
+ # Projections
985
+ _nvtx = self._env_cache.get("nvtx", False)
986
+ if _nvtx:
987
+ try:
988
+ import torch as _t
989
+
990
+ _t.cuda.nvtx.range_push("projections+rope")
991
+ except Exception:
992
+ _nvtx = False
993
+ Q_lin = self._shape_q(self.W_Q(x), B, S) # [B,S,G,h,Dk]
994
+ assert Q_lin.shape[:2] == (B, S)
995
+ # Apply RoPE to Q
996
+ pos = torch.arange(S, device=x.device)
997
+ Q = apply_rope(
998
+ Q_lin.view(B, S, self.n_heads, self.d_k).reshape(B, S, self.n_heads * self.d_k),
999
+ pos,
1000
+ scale=getattr(self, "rope_scale", 1.0),
1001
+ )
1002
+ Q = Q.view(B, S, self.n_heads, self.d_k).view(
1003
+ B, S, self.n_kv_groups, self.h_per_group, self.d_k
1004
+ )
1005
+ # K/V projections per branch
1006
+ K_sel = self._shape_kv(self.W_K_sel(x), B, S)
1007
+ V_sel = self._shape_kv(self.W_V_sel(x), B, S)
1008
+ K_win = self._shape_kv(self.W_K_win(x), B, S)
1009
+ V_win = self._shape_kv(self.W_V_win(x), B, S)
1010
+ K_cmp_raw = self._shape_kv(self.W_K_cmp(x), B, S)
1011
+ V_cmp_raw = self._shape_kv(self.W_V_cmp(x), B, S)
1012
+ G = self.n_kv_groups
1013
+ assert K_sel.shape[:3] == (B, G, S) and V_sel.shape[:3] == (B, G, S)
1014
+ assert K_win.shape[:3] == (B, G, S) and V_win.shape[:3] == (B, G, S)
1015
+ assert K_cmp_raw.shape[:3] == (B, G, S) and V_cmp_raw.shape[:3] == (B, G, S)
1016
+
1017
+ # Apply RoPE to per-branch K tensors (Q already has RoPE applied)
1018
+ pos_k = torch.arange(S, device=x.device)
1019
+ K_sel = apply_rope(K_sel, pos_k, scale=getattr(self, "rope_scale", 1.0))
1020
+ K_win = apply_rope(K_win, pos_k, scale=getattr(self, "rope_scale", 1.0))
1021
+ if _nvtx:
1022
+ try:
1023
+ _t.cuda.nvtx.range_pop()
1024
+ except Exception:
1025
+ pass
1026
+
1027
+ # Update caches (prefill uses full sequence projections)
1028
+ kv.update_selection_raw(K_sel, V_sel)
1029
+ # Build/refresh meta for selection and compressed mapping
1030
+ kv.meta = build_block_meta(
1031
+ seq_len=S, l=self.l, d=self.d, l_sel=self.l_sel, n_sel=self.n_sel, w=self.w
1032
+ )
1033
+ kv.update_window(K_win, V_win, self.w)
1034
+ if self.phi_type == "mlp":
1035
+ K_cmp, V_cmp = self._phi_apply_seq(
1036
+ K_cmp_raw, V_cmp_raw, pos=torch.arange(S, device=x.device)
1037
+ )
1038
+ else:
1039
+ K_cmp, V_cmp = avg_pool_phi_rope_kv(
1040
+ K_cmp_raw, V_cmp_raw, self.l, self.d, pos=torch.arange(S, device=x.device)
1041
+ )
1042
+ kv.update_compressed(K_cmp, V_cmp, self.l, self.d)
1043
+
1044
+ # One-time SDPA backend audit (opt-in via env)
1045
+ try:
1046
+ if (not self._sdpa_audited) and os.getenv("NSA_SDPA_AUDIT", "0").lower() in (
1047
+ "1",
1048
+ "true",
1049
+ "yes",
1050
+ ):
1051
+ self._audit_sdpa_backends_once(
1052
+ Q[:, :1],
1053
+ K_sel[:, :, : max(1, S // 8), :],
1054
+ V_sel[:, :, : max(1, S // 8), :],
1055
+ K_win[:, :, : max(1, S // 8), :],
1056
+ V_win[:, :, : max(1, S // 8), :],
1057
+ )
1058
+ except Exception:
1059
+ pass
1060
+
1061
+ # Selection scores (batched)
1062
+ scale = 1.0 / (self.d_k**0.5)
1063
+ if _nvtx:
1064
+ try:
1065
+ _t.cuda.nvtx.range_push("pcmp_all")
1066
+ except Exception:
1067
+ pass
1068
+ p_cmp_all = compute_pcmp_all(Q, kv.K_cmp, scale) # [B,S,G,h,S_cmp]
1069
+ if _nvtx:
1070
+ try:
1071
+ _t.cuda.nvtx.range_pop()
1072
+ _t.cuda.nvtx.range_push("map_pcmp_to_pslc")
1073
+ except Exception:
1074
+ pass
1075
+ p_slc_all = map_pcmp_to_pslc_batched(p_cmp_all, kv.meta) # [B,S,G,h,S_sel]
1076
+
1077
+ # M8: Optional Eq.9 verification in batched prefill
1078
+ if self._env_cache.get("verify_eq9", False):
1079
+ is_equiv, details = verify_mapping_equivalence(p_cmp_all, kv.meta)
1080
+ if not is_equiv:
1081
+ log(
1082
+ "error.eq9_verification_failed_prefill",
1083
+ msg="Eq.9 mapping verification failed in batched prefill",
1084
+ **details,
1085
+ )
1086
+ p_grp_all = p_slc_all.sum(dim=3) # [B,S,G,S_sel]
1087
+ log(
1088
+ "prefill.scores",
1089
+ B=B,
1090
+ S=S,
1091
+ S_cmp=int(kv.K_cmp.shape[2]),
1092
+ S_sel=int(kv.meta.sel_starts.numel()),
1093
+ )
1094
+
1095
+ # Batched top‑n → ranges for all positions
1096
+ if _nvtx:
1097
+ try:
1098
+ _t.cuda.nvtx.range_push("topk+ranges")
1099
+ except Exception:
1100
+ pass
1101
+ sel_ranges_all = select_topn_ranges_batched(
1102
+ p_grp_all, kv.meta, self.n_sel, S, True, 2
1103
+ ) # [B,S,G,n,2]
1104
+ if _nvtx:
1105
+ try:
1106
+ _t.cuda.nvtx.range_pop()
1107
+ _t.cuda.nvtx.range_push("branch_attn+gate")
1108
+ except Exception:
1109
+ pass
1110
+ # Update selection statistics for this prefill batch
1111
+ self._update_sel_stats_from_ranges(sel_ranges_all)
1112
+ if _nvtx:
1113
+ try:
1114
+ _t.cuda.nvtx.range_pop()
1115
+ except Exception:
1116
+ pass
1117
+
1118
+ # M8: Assert causal masking for batched selection (GPU-sync gated)
1119
+ strict_asserts = self._env_cache.get("strict_asserts", False)
1120
+ if strict_asserts and sel_ranges_all.numel() > 0:
1121
+ for t in range(S):
1122
+ t_ranges = sel_ranges_all[:, t] # [B,G,n,2]
1123
+ if t_ranges.numel() > 0:
1124
+ max_end = t_ranges[..., 1].max().item()
1125
+ assert max_end <= t + 1, (
1126
+ f"Batched selection violates causality at t={t}: max_end={max_end} > t+1={t + 1}. "
1127
+ f"Selection ranges cannot access future tokens."
1128
+ )
1129
+ log("prefill.select", n_sel=self.n_sel, l_sel=self.l_sel, ranges=sel_ranges_all)
1130
+
1131
+ # Branch attentions in parallel (parity-first for cmp/win, with optional masked SDPA gates)
1132
+ force_parity = self._env_cache.get("force_parity", False)
1133
+ fa2_all = self._env_cache.get("fa2_all_eff", False)
1134
+ fa2_win = self._env_cache.get("fa2_win_eff", False)
1135
+ fa2_cmp = self._env_cache.get("fa2_cmp_eff", False)
1136
+ use_cmp_mask = self._env_cache.get("use_cmp_mask", True) and not force_parity
1137
+ if (fa2_all or fa2_cmp) and not force_parity:
1138
+ try:
1139
+ O_cmp = compressed_attention_fa2(Q, kv.K_cmp, kv.V_cmp, self.l, self.d)
1140
+ except Exception as e:
1141
+ # M8: Fallback counter - Compressed FA2 failed in prefill
1142
+ self._fallback_counters["compressed_fa2_fails"] += 1
1143
+ self._fallback_counters["total_fallbacks"] += 1
1144
+ log(
1145
+ "warn.compressed_fa2_prefill_fallback",
1146
+ error=str(e),
1147
+ total_fails=self._fallback_counters["compressed_fa2_fails"],
1148
+ )
1149
+ # Fallback to masked SDPA
1150
+ from nsa.core.attention_kernels import batched_causal_attention_compressed_masked
1151
+
1152
+ O_cmp = batched_causal_attention_compressed_masked(
1153
+ Q, kv.K_cmp, kv.V_cmp, self.l, self.d
1154
+ )
1155
+ elif use_cmp_mask:
1156
+ from nsa.core.attention_kernels import batched_causal_attention_compressed_masked
1157
+
1158
+ O_cmp = batched_causal_attention_compressed_masked(
1159
+ Q, kv.K_cmp, kv.V_cmp, self.l, self.d
1160
+ )
1161
+ else:
1162
+ # Compressed per-t using the same kernel as sequential
1163
+ O_cmp = torch.zeros(
1164
+ (B, S, self.n_kv_groups, self.h_per_group, self.d_v),
1165
+ device=x.device,
1166
+ dtype=V_cmp.dtype,
1167
+ )
1168
+ S_cmp_full = kv.K_cmp.shape[2]
1169
+ for t in range(S):
1170
+ L = 0 if (t + 1) < self.l else min(((t + 1 - self.l) // self.d) + 1, S_cmp_full)
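+ # e.g. with l=32, d=16 at t=63 (64 tokens visible): L = (64-32)//16 + 1 = 3
+ # compressed tokens are causally visible to the query at position t.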
1171
+
1172
+ # M8: Assert causal masking - compressed tokens must respect position bounds
1173
+ if L > 0:
1174
+ # Check that compressed range doesn't exceed causal bounds
1175
+ assert L <= S_cmp_full, (
1176
+ f"Compressed range exceeds cache: L={L} > S_cmp_full={S_cmp_full} at t={t}"
1177
+ )
1178
+ # Verify causal constraint: at position t, can only see compressed tokens
1179
+ # that represent original positions up to t
1180
+ max_allowed_L = ((t + 1 - self.l) // self.d) + 1 if (t + 1) >= self.l else 0
1181
+ assert L <= max_allowed_L, (
1182
+ f"Compressed range violates causality: L={L} > max_allowed_L={max_allowed_L} "
1183
+ f"at t={t}. Compressed tokens represent future positions."
1184
+ )
1185
+
1186
+ q_t = Q[:, t].contiguous()
1187
+ k_t = kv.K_cmp[:, :, :L, :].contiguous()
1188
+ v_t = kv.V_cmp[:, :, :L, :].contiguous()
1189
+ O_cmp[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
1190
+ # Strict finite check and fallback
1191
+ if strict_asserts and not torch.isfinite(O_cmp).all():
1192
+ from nsa.core.attention_kernels import batched_causal_attention_compressed_masked
1193
+
1194
+ log("warn.prefill_cmp_nonfinite_fallback")
1195
+ O_cmp = batched_causal_attention_compressed_masked(
1196
+ Q, kv.K_cmp, kv.V_cmp, self.l, self.d
1197
+ )
1198
+ log("prefill.cmp", O_cmp=O_cmp)
1199
+
1200
+ # Selected ranges attention (prefer Triton if enabled; else packed/gather)
1201
+ use_sel_pack = self._env_cache.get("use_sel_pack", True) and not force_parity
1202
+ use_sel_varlen = self._env_cache.get("use_sel_varlen", False) and not force_parity
1203
+ # The cached flag already folds in the constructor's use_triton_sel; parenthesize so
+ # the parity override disables Triton selection here as it does on the decode path.
+ use_triton_sel = self._env_cache.get("use_triton_sel", False) and not force_parity
1206
+ force_sel_mask = self._env_cache.get("force_sel_mask", False) and not force_parity
1207
+ if force_sel_mask:
1208
+ try:
1209
+ O_sel = grouped_selection_attention_masked(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1210
+ log("prefill.sel.path", path="masked_forced")
1211
+ except Exception as e:
1212
+ # Fallback to gather SDPA
1213
+ self._fallback_counters["selection_mask_fails"] += 1
1214
+ self._fallback_counters["total_fallbacks"] += 1
1215
+ log("warn.masked_selection_prefill_forced_fallback",
1216
+ error=str(e),
1217
+ Q_shape=list(Q.shape) if hasattr(Q, 'shape') else list(Q_t.shape),
1218
+ K_shape=list(kv.K_sel.shape) if hasattr(kv, 'K_sel') else list(K_sel_t.shape),
1219
+ V_shape=list(kv.V_sel.shape) if hasattr(kv, 'V_sel') else list(V_sel_t.shape),
1220
+ ranges_shape=list(sel_ranges_all.shape) if 'sel_ranges_all' in locals() else list(sel_ranges.shape) if sel_ranges is not None else None,
1221
+ total_fails=self._fallback_counters["selection_mask_fails"])
1222
+ O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1223
+ elif use_triton_sel:
1224
+ try:
1225
+ from nsa.kernels.triton_sel_kernel import selection_attention_triton
1226
+
1227
+ O_sel = selection_attention_triton(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1228
+ log("prefill.sel.path", path="triton")
1229
+ except Exception as e:
1230
+ # M8: Fallback counter - Triton selection failed in prefill
1231
+ self._fallback_counters["selection_triton_fails"] += 1
1232
+ self._fallback_counters["total_fallbacks"] += 1
1233
+ log(
1234
+ "warn.triton_selection_prefill_fallback",
1235
+ error=str(e),
1236
+ Q_shape=list(Q.shape),
1237
+ K_shape=list(kv.K_sel.shape),
1238
+ V_shape=list(kv.V_sel.shape),
1239
+ ranges_shape=list(sel_ranges_all.shape),
1240
+ total_fails=self._fallback_counters["selection_triton_fails"],
1241
+ )
1242
+ # Fallback to packed SDPA
1243
+ O_sel = grouped_selection_attention_packed(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1244
+ elif use_sel_varlen:
1245
+ try:
1246
+ from nsa.core.attention_kernels import selection_attention_varlen_all
1247
+
1248
+ O_sel = selection_attention_varlen_all(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1249
+ log("prefill.sel.path", path="varlen")
1250
+ except Exception as e:
1251
+ # Fallback counter reuse for selection pack failures
1252
+ self._fallback_counters["selection_pack_fails"] += 1
1253
+ self._fallback_counters["total_fallbacks"] += 1
1254
+ log(
1255
+ "warn.selection_varlen_prefill_fallback",
1256
+ error=str(e),
1257
+ total_fails=self._fallback_counters["selection_pack_fails"],
1258
+ )
1259
+ # Fallback to packed SDPA
1260
+ O_sel = grouped_selection_attention_packed(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1261
+ log("prefill.sel.path", path="packed")
1262
+ elif use_sel_pack:
1263
+ try:
1264
+ O_sel = grouped_selection_attention_packed(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1265
+ except Exception as e:
1266
+ # M8: Fallback counter - Packed selection failed in prefill
1267
+ self._fallback_counters["selection_pack_fails"] += 1
1268
+ self._fallback_counters["total_fallbacks"] += 1
1269
+ log(
1270
+ "warn.packed_selection_prefill_fallback",
1271
+ error=str(e),
1272
+ total_fails=self._fallback_counters["selection_pack_fails"],
1273
+ )
1274
+ # Fallback to gather SDPA
1275
+ O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1276
+ elif self._env_cache.get("use_sel_mask", False):
1277
+ try:
1278
+ O_sel = grouped_selection_attention_masked(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1279
+ log("prefill.sel.path", path="masked")
1280
+ except Exception as e:
1281
+ # M8: Fallback counter - Masked selection failed in prefill
1282
+ self._fallback_counters["selection_mask_fails"] += 1
1283
+ self._fallback_counters["total_fallbacks"] += 1
1284
+ log(
1285
+ "warn.masked_selection_prefill_fallback",
1286
+ error=str(e),
1287
+ total_fails=self._fallback_counters["selection_mask_fails"],
1288
+ )
1289
+ # Fallback to gather SDPA
1290
+ O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1291
+ else:
1292
+ O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1293
+ log("prefill.sel.path", path="gather")
1294
+ if strict_asserts and not torch.isfinite(O_sel).all():
1295
+ log("warn.prefill_sel_nonfinite_fallback")
1296
+ O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
1297
+ log("prefill.sel", O_sel=O_sel)
1298
+
1299
+ use_win_mask = self._env_cache.get("use_win_mask", True) and not force_parity
1300
+ if (fa2_all or fa2_win) and not force_parity:
1301
+ try:
1302
+ O_win = sliding_window_attention_fa2(Q, K_win, V_win, self.w)
1303
+ except Exception as e:
1304
+ # M8: Fallback counter - Sliding FA2 failed in prefill
1305
+ self._fallback_counters["sliding_fa2_fails"] += 1
1306
+ self._fallback_counters["total_fallbacks"] += 1
1307
+ log(
1308
+ "warn.sliding_fa2_prefill_fallback",
1309
+ error=str(e),
1310
+ total_fails=self._fallback_counters["sliding_fa2_fails"],
1311
+ )
1312
+ # Fallback to masked SDPA
1313
+ from nsa.core.attention_kernels import sliding_window_attention
1314
+
1315
+ O_win = sliding_window_attention(Q, K_win, V_win, self.w)
1316
+ elif use_win_mask:
1317
+ from nsa.core.attention_kernels import sliding_window_attention
1318
+
1319
+ O_win = sliding_window_attention(Q, K_win, V_win, self.w)
1320
+ else:
1321
+ # Sliding per-t using the same kernel as sequential
1322
+ O_win = torch.zeros(
1323
+ (B, S, self.n_kv_groups, self.h_per_group, self.d_v),
1324
+ device=x.device,
1325
+ dtype=V_win.dtype,
1326
+ )
1327
+ for t in range(S):
1328
+ end = t + 1
1329
+ start = max(0, end - self.w)
1330
+
1331
+ # M8: Assert causal masking - sliding window must not exceed current position
1332
+ assert end <= t + 1, (
1333
+ f"Sliding window violates causality: end={end} > t+1={t + 1} at position t={t}. "
1334
+ f"This indicates window is accessing future tokens."
1335
+ )
1336
+ assert start <= end, (
1337
+ f"Sliding window has invalid range: start={start} > end={end} at position t={t}."
1338
+ )
1339
+
1340
+ q_t = Q[:, t].contiguous()
1341
+ k_t = K_win[:, :, start:end, :].contiguous()
1342
+ v_t = V_win[:, :, start:end, :].contiguous()
1343
+ O_win[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
1344
+ if strict_asserts and not torch.isfinite(O_win).all():
1345
+ from nsa.core.attention_kernels import sliding_window_attention
1346
+
1347
+ log("warn.prefill_win_nonfinite_fallback")
1348
+ O_win = sliding_window_attention(Q, K_win, V_win, self.w)
1349
+ log("prefill.win", O_win=O_win)
1350
+
1351
+ # Gates and combine
1352
+ q_gp = Q.mean(dim=3) # [B,S,G,Dk]
1353
+ if self._env_cache.get("gate_compile", False):
1354
+ try:
1355
+ fused = self._gate_fused_bsg
1356
+ if fused is None:
1357
+ fused = _fused_gate_combine_bsg
1358
+ try:
1359
+ fused = torch.compile(fused, mode="reduce-overhead") # type: ignore[attr-defined]
1360
+ except Exception:
1361
+ pass
1362
+ self._gate_fused_bsg = fused
1363
+ O = fused(
1364
+ q_gp,
1365
+ O_cmp,
1366
+ O_sel,
1367
+ O_win,
1368
+ self.gate.fc1.weight,
1369
+ self.gate.fc1.bias,
1370
+ self.gate.fc2.weight,
1371
+ self.gate.fc2.bias,
1372
+ float(self.gate_temp),
1373
+ )
1374
+ except Exception:
1375
+ gates = self.gate(q_gp.reshape(B * S * self.n_kv_groups, self.d_k), tau=self.gate_temp)
1376
+ if self._env_cache.get("stopgrad_gates", False):
1377
+ gates = gates.detach()
1378
+ gates = gates.view(B, S, self.n_kv_groups, 3) # [B,S,G,3]
1379
+ self._update_gate_stats(gates)
1380
+ w_cmp = gates[..., 0:1].unsqueeze(3)
1381
+ w_sel = gates[..., 1:2].unsqueeze(3)
1382
+ w_win = gates[..., 2:3].unsqueeze(3)
1383
+ O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win # [B,S,G,h,Dv]
1384
+ else:
1385
+ gates = self.gate(q_gp.reshape(B * S * self.n_kv_groups, self.d_k), tau=self.gate_temp)
1386
+ if self._env_cache.get("stopgrad_gates", False):
1387
+ gates = gates.detach()
1388
+ gates = gates.view(B, S, self.n_kv_groups, 3) # [B,S,G,3]
1389
+ self._update_gate_stats(gates)
1390
+ w_cmp = gates[..., 0:1].unsqueeze(3)
1391
+ w_sel = gates[..., 1:2].unsqueeze(3)
1392
+ w_win = gates[..., 2:3].unsqueeze(3)
1393
+ O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win # [B,S,G,h,Dv]
1394
+
1395
+ # Output projection
1396
+ O_heads = O.reshape(B, S, self.n_kv_groups * self.h_per_group, self.d_v)
1397
+ out = self.out(O_heads.reshape(B, S, -1))
1398
+ log("prefill.out", out=out)
1399
+
1400
+ # Optional debug compare: sequential-style per-token recompute to measure MAE
1401
+ if self._env_cache.get("debug_compare", False):
1402
+ with torch.no_grad():
1403
+ # Compressed per-token recompute
1404
+ O_cmp_seq = torch.zeros_like(O_cmp)
1405
+ S_cmp = kv.K_cmp.shape[2]
1406
+ for t in range(S):
1407
+ L = 0 if (t + 1) < self.l else min(((t + 1 - self.l) // self.d) + 1, S_cmp)
1408
+
1409
+ # M8: Assert causal masking in debug recompute
1410
+ if L > 0:
1411
+ assert L <= S_cmp, (
1412
+ f"Debug compressed range exceeds cache: L={L} > S_cmp={S_cmp} at t={t}"
1413
+ )
1414
+
1415
+ q_t = Q[:, t].contiguous()
1416
+ k_t = kv.K_cmp[:, :, :L, :].contiguous()
1417
+ v_t = kv.V_cmp[:, :, :L, :].contiguous()
1418
+ O_cmp_seq[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
1419
+ cmp_mae = (O_cmp - O_cmp_seq).abs().mean().item()
1420
+ print(f"NSA-DBG cmp_mae={cmp_mae:.6e}")
1421
+
1422
+ # Sliding per-token recompute
1423
+ O_win_seq = torch.zeros_like(O_win)
1424
+ for t in range(S):
1425
+ end = t + 1
1426
+ start = max(0, end - self.w)
1427
+ q_t = Q[:, t].contiguous()
1428
+ k_t = K_win[:, :, start:end, :].contiguous()
1429
+ v_t = V_win[:, :, start:end, :].contiguous()
1430
+ O_win_seq[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
1431
+ win_mae = (O_win - O_win_seq).abs().mean().item()
1432
+ print(f"NSA-DBG win_mae={win_mae:.6e}")
1433
+
1434
+ # Final output recompute using seq per-branch; recompute gates explicitly so this
+ # path also works when the fused (compiled) gate combine above was taken and
+ # `gates` was never materialized.
+ gates_dbg = self.gate(
+ q_gp.reshape(B * S * self.n_kv_groups, self.d_k), tau=self.gate_temp
+ ).view(B, S, self.n_kv_groups, 3)
+ w_cmp_dbg = gates_dbg[..., 0:1].unsqueeze(-1)
+ w_sel_dbg = gates_dbg[..., 1:2].unsqueeze(-1)
+ w_win_dbg = gates_dbg[..., 2:3].unsqueeze(-1)
1438
+ O_seq = w_cmp_dbg * O_cmp_seq + w_sel_dbg * O_sel + w_win_dbg * O_win_seq
1439
+ O_heads_seq = O_seq.reshape(B, S, self.n_kv_groups * self.h_per_group, self.d_v)
1440
+ out_seq = self.out(O_heads_seq.reshape(B, S, -1))
1441
+ out_mae = (out - out_seq).abs().mean().item()
1442
+ print(f"NSA-DBG out_mae={out_mae:.6e}")
1443
+ return out, kv
1444
+
1445
+ def _audit_sdpa_backends_once(
1446
+ self,
1447
+ Q: torch.Tensor, # [B,1,G,h,Dk]
1448
+ K_sel: torch.Tensor, # [B,G,S,Dk]
1449
+ V_sel: torch.Tensor, # [B,G,S,Dv]
1450
+ K_win: torch.Tensor, # [B,G,S,Dk]
1451
+ V_win: torch.Tensor, # [B,G,S,Dv]
1452
+ ) -> None:
1453
+ if self._sdpa_audited:
1454
+ return
1455
+ try:
1456
+ from torch.nn.attention import sdpa_kernel
1457
+ except Exception:
1458
+ # Older torch, skip audit
1459
+ self._sdpa_audited = True
1460
+ return
1461
+ B = Q.shape[0]
1462
+ G = self.n_kv_groups
1463
+ h = self.h_per_group
1464
+ # Prepare a small representative slice per branch
1465
+ q = Q[:, 0] # [B,G,h,Dk]
1466
+ # Ensure contiguity
1467
+ q = q.contiguous()
1468
+ ks = K_sel.contiguous()
1469
+ vs = V_sel.contiguous()
1470
+ kw = K_win.contiguous()
1471
+ vw = V_win.contiguous()
1472
+
1473
+ def _probe(tag: str, k: torch.Tensor, v: torch.Tensor) -> str:
1474
+ try:
1475
+ with sdpa_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
1476
+ q2 = q.reshape(B * G * h, 1, self.d_k)
1477
+ k2 = (
1478
+ k.unsqueeze(2)
1479
+ .expand(B, G, h, k.shape[2], self.d_k)
1480
+ .reshape(B * G * h, k.shape[2], self.d_k)
1481
+ )
1482
+ v2 = (
1483
+ v.unsqueeze(2)
1484
+ .expand(B, G, h, v.shape[2], self.d_v)
1485
+ .reshape(B * G * h, v.shape[2], self.d_v)
1486
+ )
1487
+ _ = F.scaled_dot_product_attention(
1488
+ q2.contiguous(), k2.contiguous(), v2.contiguous(), is_causal=True
1489
+ )
1490
+ return "flash"
1491
+ except Exception:
1492
+ return "fallback"
1493
+
1494
+ try:
1495
+ b_sel = _probe("cmp/win(sel)", ks, vs)
1496
+ b_win = _probe("win", kw, vw)
1497
+ log("sdpa.audit", sel=b_sel, win=b_win)
1498
+ except Exception:
1499
+ pass
1500
+ self._sdpa_audited = True
1501
+
1502
+ def _forward_prefill_via_decode(
1503
+ self, x: torch.Tensor, kv: NSA_KV
1504
+ ) -> tuple[torch.Tensor, NSA_KV]:
1505
+ """Prefill by stepping decode one token at a time.
1506
+
1507
+ This path avoids recursion back into prefill and guarantees progress.
1508
+ """
1509
+ B, S, _ = x.shape
1510
+ outs = []
1511
+ for t in range(S):
1512
+ out_t, kv = self.forward(x[:, t : t + 1], kv, prefill=False)
1513
+ outs.append(out_t)
1514
+ return torch.cat(outs, dim=1), kv
1515
+
1516
+ def _forward_prefill_sequential(
1517
+ self, x: torch.Tensor, kv: NSA_KV
1518
+ ) -> tuple[torch.Tensor, NSA_KV]:
1519
+ """
1520
+ Reference prefill path (sequential per‑token), used for parity checks.
1521
+ """
1522
+ B, S, _ = x.shape
1523
+ # Projections
1524
+ Q_lin = self._shape_q(self.W_Q(x), B, S) # [B,S,G,h,Dk]
1525
+ pos = torch.arange(S, device=x.device)
1526
+ Q = apply_rope(
1527
+ Q_lin.view(B, S, self.n_heads, self.d_k).reshape(B, S, self.n_heads * self.d_k),
1528
+ pos,
1529
+ scale=getattr(self, "rope_scale", 1.0),
1530
+ )
1531
+ Q = Q.view(B, S, self.n_heads, self.d_k).view(
1532
+ B, S, self.n_kv_groups, self.h_per_group, self.d_k
1533
+ )
1534
+ K_sel = self._shape_kv(self.W_K_sel(x), B, S)
1535
+ V_sel = self._shape_kv(self.W_V_sel(x), B, S)
1536
+ K_win = self._shape_kv(self.W_K_win(x), B, S)
1537
+ V_win = self._shape_kv(self.W_V_win(x), B, S)
1538
+ K_cmp_raw = self._shape_kv(self.W_K_cmp(x), B, S)
1539
+ V_cmp_raw = self._shape_kv(self.W_V_cmp(x), B, S)
1540
+
1541
+ # Apply RoPE to per-branch K tensors to align with batched path
1542
+ pos_k = torch.arange(S, device=x.device)
1543
+ K_sel = apply_rope(K_sel, pos_k, scale=getattr(self, "rope_scale", 1.0))
1544
+ K_win = apply_rope(K_win, pos_k, scale=getattr(self, "rope_scale", 1.0))
1545
+
1546
+ kv.update_selection_raw(K_sel, V_sel)
1547
+ kv.meta = build_block_meta(
1548
+ seq_len=S, l=self.l, d=self.d, l_sel=self.l_sel, n_sel=self.n_sel, w=self.w
1549
+ )
1550
+ kv.update_window(K_win, V_win, self.w)
1551
+ if self.phi_type == "mlp":
1552
+ K_cmp, V_cmp = self._phi_apply_seq(
1553
+ K_cmp_raw, V_cmp_raw, pos=torch.arange(S, device=x.device)
1554
+ )
1555
+ else:
1556
+ K_cmp, V_cmp = avg_pool_phi_rope_kv(
1557
+ K_cmp_raw, V_cmp_raw, self.l, self.d, pos=torch.arange(S, device=x.device)
1558
+ )
1559
+ kv.update_compressed(K_cmp, V_cmp, self.l, self.d)
1560
+
1561
+ # Precompute p_grp_all batched for reuse per t
1562
+ scale = 1.0 / (self.d_k**0.5)
1563
+ p_cmp_all = compute_pcmp_all(Q, kv.K_cmp, scale) # [B,S,G,h,S_cmp]
1564
+ p_slc_all = map_pcmp_to_pslc_batched(p_cmp_all, kv.meta) # [B,S,G,h,S_sel]
1565
+ p_grp_all = p_slc_all.sum(dim=3) # [B,S,G,S_sel]
1566
+
1567
+ outs = []
1568
+ sel_ranges_accum: List[torch.Tensor] = []
1569
+ for t in range(S):
1570
+ p_grp = p_grp_all[:, t] # [B,G,S_sel]
1571
+ sel_ranges = select_topn_ranges(p_grp, kv.meta, self.n_sel, t, True, 2)
1572
+ sel_ranges_accum.append(sel_ranges)
1573
+ Q_t = Q[:, t]
1574
+ K_sel_t = kv.K_sel[:, :, : t + 1, :]
1575
+ V_sel_t = kv.V_sel[:, :, : t + 1, :]
1576
+ # Selection attention routing (mirror decode/batched semantics)
1577
+ force_parity = self._env_cache.get("force_parity", False)
1578
+ use_sel_pack = self._env_cache.get("use_sel_pack", True) and not force_parity
1579
+ use_triton_sel = self._env_cache.get("use_triton_sel", False) and not force_parity
1580
+ use_cuda_sel = self._env_cache.get("use_cuda_sel", False) and not force_parity
1581
+ force_sel_mask = self._env_cache.get("force_sel_mask", False) and not force_parity
1582
+ if force_sel_mask:
1583
+ try:
1584
+ O_sel_bt = grouped_selection_attention_masked(
1585
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
1586
+ )
1587
+ O_sel = O_sel_bt[:, 0]
1588
+ log("prefill.sel.path", path="masked_forced")
1589
+ except Exception as e:
1590
+ self._fallback_counters["selection_mask_fails"] += 1
1591
+ self._fallback_counters["total_fallbacks"] += 1
1592
+ log("warn.masked_selection_prefill_forced_fallback",
1593
+ error=str(e),
1594
+ Q_shape=list(Q.shape) if hasattr(Q, 'shape') else list(Q_t.shape),
1595
+ K_shape=list(kv.K_sel.shape) if hasattr(kv, 'K_sel') else list(K_sel_t.shape),
1596
+ V_shape=list(kv.V_sel.shape) if hasattr(kv, 'V_sel') else list(V_sel_t.shape),
1597
+ ranges_shape=list(sel_ranges_all.shape) if 'sel_ranges_all' in locals() else list(sel_ranges.shape) if sel_ranges is not None else None,
1598
+ total_fails=self._fallback_counters["selection_mask_fails"])
1599
+ O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
1600
+ elif use_triton_sel:
1601
+ try:
1602
+ from nsa.kernels.triton_sel_kernel import selection_attention_triton
1603
+
1604
+ O_sel_bt = selection_attention_triton(
1605
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
1606
+ )
1607
+ O_sel = O_sel_bt[:, 0]
1608
+ log("prefill.sel.path", path="triton")
1609
+ except Exception as e:
1610
+ # Fallback counter - Triton selection failed (sequential prefill)
1611
+ self._fallback_counters["selection_triton_fails"] += 1
1612
+ self._fallback_counters["total_fallbacks"] += 1
1613
+ log(
1614
+ "warn.triton_selection_prefill_fallback",
1615
+ error=str(e),
1616
+ total_fails=self._fallback_counters["selection_triton_fails"],
1617
+ )
1618
+ # Fallback to packed SDPA
1619
+ O_sel_bt = grouped_selection_attention_packed(
1620
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
1621
+ )
1622
+ O_sel = O_sel_bt[:, 0]
1623
+ elif use_cuda_sel:
1624
+ try:
1625
+ from nsa.kernels.cuda_sel_kernel import selection_attention_cuda
1626
+
1627
+ O_sel_bt = selection_attention_cuda(
1628
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
1629
+ )
1630
+ O_sel = O_sel_bt[:, 0]
1631
+ except Exception as e:
1632
+ # Fallback counter - CUDA selection failed (sequential prefill)
1633
+ self._fallback_counters["selection_cuda_fails"] += 1
1634
+ self._fallback_counters["total_fallbacks"] += 1
1635
+ log(
1636
+ "warn.cuda_selection_prefill_fallback",
1637
+ error=str(e),
1638
+ total_fails=self._fallback_counters["selection_cuda_fails"],
1639
+ )
1640
+ # Fallback to packed SDPA
1641
+ O_sel_bt = grouped_selection_attention_packed(
1642
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
1643
+ )
1644
+ O_sel = O_sel_bt[:, 0]
1645
+ elif use_sel_pack:
1646
+ try:
1647
+ O_sel_bt = grouped_selection_attention_packed(
1648
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
1649
+ )
1650
+ O_sel = O_sel_bt[:, 0]
1651
+ log("prefill.sel.path", path="packed")
1652
+ except Exception as e:
1653
+ # Fallback counter - Packed selection failed (sequential prefill)
1654
+ self._fallback_counters["selection_pack_fails"] += 1
1655
+ self._fallback_counters["total_fallbacks"] += 1
1656
+ log(
1657
+ "warn.packed_selection_prefill_fallback",
1658
+ error=str(e),
1659
+ total_fails=self._fallback_counters["selection_pack_fails"],
1660
+ )
1661
+ # Fallback to gather SDPA
1662
+ O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
1663
+ elif self._env_cache.get("use_sel_mask", False) and not force_parity:
1664
+ try:
1665
+ O_sel_bt = grouped_selection_attention_masked(
1666
+ Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
1667
+ )
1668
+ O_sel = O_sel_bt[:, 0]
1669
+ log("prefill.sel.path", path="masked")
1670
+ except Exception as e:
1671
+ # Fallback counter - Masked selection failed (sequential prefill)
1672
+ self._fallback_counters["selection_mask_fails"] += 1
1673
+ self._fallback_counters["total_fallbacks"] += 1
1674
+ log(
1675
+ "warn.masked_selection_prefill_fallback",
1676
+ error=str(e),
1677
+ total_fails=self._fallback_counters["selection_mask_fails"],
1678
+ )
1679
+ # Fallback to gather SDPA
1680
+ O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
1681
+ else:
1682
+ O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
1683
+ win_len = min(self.w, t + 1)
1684
+ K_w = kv.K_win[:, :, t + 1 - win_len : t + 1, :]
1685
+ V_w = kv.V_win[:, :, t + 1 - win_len : t + 1, :]
1686
+ O_win = attention_bgh(Q_t.contiguous(), K_w.contiguous(), V_w.contiguous(), causal=True)
1687
+ S_cmp_t = 0 if (t + 1) < self.l else (t + 1 - self.l) // self.d + 1
1688
+ O_cmp = attention_bgh(
1689
+ Q_t.contiguous(),
1690
+ kv.K_cmp[:, :, :S_cmp_t, :].contiguous(),
1691
+ kv.V_cmp[:, :, :S_cmp_t, :].contiguous(),
1692
+ causal=True,
1693
+ )
1694
+ q_gp = Q_t.mean(dim=2, dtype=Q_t.dtype)
1695
+ gates = self.gate(q_gp, tau=self.gate_temp)
1696
+ if self._env_cache.get("stopgrad_gates", False):
1697
+ gates = gates.detach()
1698
+
1699
+ # Update gate statistics for M8 monitoring (accumulate across steps)
1700
+ self._update_gate_stats(gates)
1701
+
1702
+ w_cmp = gates[..., 0:1].unsqueeze(-1)
1703
+ w_sel = gates[..., 1:2].unsqueeze(-1)
1704
+ w_win = gates[..., 2:3].unsqueeze(-1)
1705
+ O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win
1706
+ O_heads = O.reshape(B, self.n_heads, self.d_v)
1707
+ out_t = self.out(O_heads.reshape(B, 1, -1))
1708
+ outs.append(out_t)
1709
+ out = torch.cat(outs, dim=1)
1710
+ # Aggregate selection stats across all t in this prefill (sequential path)
1711
+ try:
1712
+ if sel_ranges_accum:
1713
+ # Stack to [T,B,G,n,2] then permute to [B,T,G,n,2]
1714
+ rs = torch.stack(sel_ranges_accum, dim=0).permute(1, 0, 2, 3, 4)
1715
+ self._update_sel_stats_from_ranges(rs)
1716
+ except Exception:
1717
+ pass
1718
+ return out, kv
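For reference, a toy standalone sketch of the three-branch gate mixing performed above (made-up shapes; `gates` stands in for the small gate MLP's softmax output, which is uniform at initialization because its last layer is zero-initialized):

```python
import torch

B, G, h, Dv = 1, 2, 4, 8
gates = torch.softmax(torch.zeros(B, G, 3), dim=-1)           # zero logits -> uniform 1/3 per branch
O_cmp, O_sel, O_win = (torch.randn(B, G, h, Dv) for _ in range(3))
w_cmp, w_sel, w_win = (gates[..., i:i + 1].unsqueeze(-1) for i in range(3))
O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win             # [B, G, h, Dv]
assert O.shape == (B, G, h, Dv)
```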
1719
+
1720
+ def _sdpa_full(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
1721
+ # Q: [B,G,h,Dk]; K/V: [B,G,S,D*] -> out [B,G,h,Dv]
1722
+ B, G, h, Dk = Q.shape
1723
+ S = K.shape[2]
1724
+ q = Q.reshape(B * G * h, 1, Dk).contiguous()
1725
+ k = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B * G * h, S, Dk).contiguous()
1726
+ v = (
1727
+ V.unsqueeze(2)
1728
+ .expand(B, G, h, S, V.shape[-1])
1729
+ .reshape(B * G * h, S, V.shape[-1])
1730
+ .contiguous()
1731
+ )
1732
+ attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
1733
+ o = attn.squeeze(1).reshape(B, G, h, -1)
1734
+ return o
1735
+
1736
+ def _phi_apply_seq(
1737
+ self, K_raw: torch.Tensor, V_raw: torch.Tensor, pos: torch.Tensor
1738
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1739
+ """Apply learnable ϕ over the full sequence using depthwise Conv1d initialized to avg.
1740
+ Expects K_raw,V_raw: [B,G,S,D*]; returns [B,G,S_cmp,D*].
1741
+ """
1742
+ assert self.phi_k_conv is not None and self.phi_v_conv is not None
1743
+ B, G, S, Dk = K_raw.shape
1744
+ Dv = V_raw.shape[-1]
1745
+ K_rope = apply_rope(K_raw, pos, scale=getattr(self, "rope_scale", 1.0))
1746
+ Kx = K_rope.permute(0, 1, 3, 2).reshape(B * G, Dk, S)
1747
+ Vx = V_raw.permute(0, 1, 3, 2).reshape(B * G, Dv, S)
1748
+ Kc = self.phi_k_conv(Kx)
1749
+ Vc = self.phi_v_conv(Vx)
1750
+ S_cmp = Kc.shape[-1]
1751
+ K_cmp = Kc.reshape(B, G, Dk, S_cmp).permute(0, 1, 3, 2).contiguous()
1752
+ V_cmp = Vc.reshape(B, G, Dv, S_cmp).permute(0, 1, 3, 2).contiguous()
1753
+ return K_cmp, V_cmp
1754
+
1755
+ def _phi_apply_last(
1756
+ self, K_last: torch.Tensor, V_last: torch.Tensor, pos_last: torch.Tensor
1757
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1758
+ """Emit a single compressed token from the last l raw tokens using Conv1d with kernel=l,stride=d.
1759
+ Inputs: [B,G,l,D*] -> Outputs: [B,G,1,D*].
1760
+ """
1761
+ assert self.phi_k_conv is not None and self.phi_v_conv is not None
1762
+ B, G, lwin, Dk = K_last.shape
1763
+ Dv = V_last.shape[-1]
1764
+ assert lwin == self.l, "decode emission expects exactly l tokens"
1765
+ K_rope = apply_rope(K_last, pos_last, scale=getattr(self, "rope_scale", 1.0))
1766
+ Kx = K_rope.permute(0, 1, 3, 2).reshape(B * G, Dk, lwin)
1767
+ Vx = V_last.permute(0, 1, 3, 2).reshape(B * G, Dv, lwin)
1768
+ Kc = self.phi_k_conv(Kx)
1769
+ Vc = self.phi_v_conv(Vx)
1770
+ K_cmp_new = Kc.reshape(B, G, Dk, 1).permute(0, 1, 3, 2).contiguous()
1771
+ V_cmp_new = Vc.reshape(B, G, Dv, 1).permute(0, 1, 3, 2).contiguous()
1772
+ return K_cmp_new, V_cmp_new
1773
+
1774
+ def _sdpa_over_ranges(
1775
+ self,
1776
+ Q: torch.Tensor,
1777
+ K: torch.Tensor,
1778
+ V: torch.Tensor,
1779
+ ranges: torch.Tensor,
1780
+ ) -> torch.Tensor:
1781
+ """
1782
+ SDPA over concatenated gathered tokens per (B,G) according to `ranges`.
1783
+
1784
+ Args:
1785
+ Q: [B,G,h,Dk]
1786
+ K: [B,G,S_kv,Dk]
1787
+ V: [B,G,S_kv,Dv]
1788
+ ranges: [B,G,n,2] start/end pairs
1789
+ Returns:
1790
+ [B,G,h,Dv]
1791
+ """
1792
+ # Concatenate gathered tokens per (B,G)
1793
+ B, G, h, Dk = Q.shape
1794
+ Dv = V.shape[-1]
1795
+ outs = []
1796
+ S_kv = K.shape[2]
1797
+ strict_asserts = (
1798
+ self._env_cache.get("strict_asserts", False) if hasattr(self, "_env_cache") else False
1799
+ )
1800
+ for b in range(B):
1801
+ row = []
1802
+ for g in range(G):
1803
+ # Clamp and validate ranges to avoid invalid or oversized indices
1804
+ r = ranges[b, g].to(dtype=torch.int64, device=K.device) # [n,2]
1805
+ if r.numel() == 0:
1806
+ valid_pairs = torch.empty((0, 2), dtype=torch.int64, device=K.device)
1807
+ else:
1808
+ s = r[:, 0].clamp_(0, S_kv)
1809
+ e = r[:, 1].clamp_(0, S_kv)
1810
+ valid = e > s
1811
+ valid_pairs = torch.stack([s[valid], e[valid]], dim=-1)
1812
+
1813
+ # M8: Assert bounds for gathered ranges (GPU-sync gated)
1814
+ if strict_asserts and valid_pairs.numel() > 0:
1815
+ max_end = valid_pairs[:, 1].max().item()
1816
+ assert max_end <= S_kv, (
1817
+ f"Selection range exceeds sequence length: max_end={max_end} > S_kv={S_kv} "
1818
+ f"at batch={b}, group={g}."
1819
+ )
1820
+ # Build a boolean mask over S_kv to gather selected tokens (limits worst-case size)
1821
+ if valid_pairs.numel() > 0:
1822
+ m = torch.zeros((S_kv,), dtype=torch.bool, device=K.device)
1823
+ for s_e in valid_pairs:
1824
+ s_i = int(s_e[0].item())
1825
+ e_i = int(s_e[1].item())
1826
+ if e_i > s_i:
1827
+ m[s_i:e_i] = True
1828
+ idx = m.nonzero(as_tuple=False).squeeze(-1)
1829
+ else:
1830
+ idx = torch.empty((0,), dtype=torch.int64, device=K.device)
1831
+ k = (
1832
+ K[b, g, idx]
1833
+ if idx.numel() > 0
1834
+ else torch.zeros((1, Dk), device=K.device, dtype=K.dtype)
1835
+ )
1836
+ v = (
1837
+ V[b, g, idx]
1838
+ if idx.numel() > 0
1839
+ else torch.zeros((1, Dv), device=K.device, dtype=V.dtype)
1840
+ )
1841
+ q = Q[b, g] # [h,Dk]
1842
+ attn = F.scaled_dot_product_attention(
1843
+ q.unsqueeze(0).contiguous(),
1844
+ k.unsqueeze(0).contiguous(),
1845
+ v.unsqueeze(0).contiguous(),
1846
+ is_causal=True,
1847
+ )
1848
+ row.append(attn.squeeze(0)) # [h,Dv]
1849
+ outs.append(torch.stack(row, dim=0)) # [G,h,Dv]
1850
+ return torch.stack(outs, dim=0) # [B,G,h,Dv]
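As a standalone illustration of the gather-then-SDPA idea behind `_sdpa_over_ranges` (hypothetical shapes; the repo method additionally clamps ranges, handles the B/G dimensions, and applies its causal flag):

```python
import torch
import torch.nn.functional as F

S_kv, Dk, Dv, h = 32, 64, 64, 2
K = torch.randn(S_kv, Dk)
V = torch.randn(S_kv, Dv)
q = torch.randn(h, Dk)                                    # one query vector per head
ranges = [(0, 8), (16, 24)]                               # selected [start, end) token spans
idx = torch.cat([torch.arange(s, e) for s, e in ranges])  # gather the selected tokens
o = F.scaled_dot_product_attention(
    q.unsqueeze(0), K[idx].unsqueeze(0), V[idx].unsqueeze(0)
)
assert o.shape == (1, h, Dv)
```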
nsa/core/packing.py ADDED
@@ -0,0 +1,114 @@
1
+ from __future__ import annotations
2
+ from typing import List
3
+
4
+ import torch
5
+
6
+
7
+ def compute_sliding_lengths(S: int, w: int, device: torch.device) -> torch.Tensor:
8
+ """
9
+ Return per-row window lengths for sliding attention: L_t = min(w, t+1)
10
+ Shape: [S]
11
+ """
12
+ tpos = torch.arange(S, device=device)
13
+ return (tpos + 1).clamp_max(w)
14
+
15
+
16
+ def compute_compressed_lengths(
17
+ S: int, l: int, d: int, S_cmp: int, device: torch.device
18
+ ) -> torch.Tensor:
19
+ """
20
+ Return per-row valid compressed lengths: num_cmp(t)
21
+ Shape: [S]
22
+ """
23
+ tpos = torch.arange(S, device=device)
24
+ return torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(min=0, max=S_cmp)
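A quick sanity check of the two length helpers above (values worked out by hand for small settings):

```python
import torch
from nsa.core.packing import compute_sliding_lengths, compute_compressed_lengths

dev = torch.device("cpu")
print(compute_sliding_lengths(10, 4, dev).tolist())
# [1, 2, 3, 4, 4, 4, 4, 4, 4, 4]
print(compute_compressed_lengths(10, l=4, d=2, S_cmp=4, device=dev).tolist())
# [0, 0, 0, 1, 1, 2, 2, 3, 3, 4]
```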
25
+
26
+
27
+ def build_length_buckets(lengths: torch.Tensor) -> List[torch.Tensor]:
28
+ """
29
+ Group row indices by identical length.
30
+ Args:
31
+ lengths: [S] int tensor
32
+ Returns:
33
+ List of index tensors, one per unique length (descending by length)
34
+ """
35
+ if lengths.numel() == 0:
36
+ return []
37
+ unique = torch.unique(lengths, sorted=True)
38
+ # sort descending so larger buckets processed first
39
+ unique = torch.flip(unique, dims=[0])
40
+ buckets: List[torch.Tensor] = []
41
+ for L in unique.tolist():
42
+ idx = torch.nonzero(lengths == int(L), as_tuple=False).flatten()
43
+ buckets.append(idx)
44
+ return buckets
45
+
46
+
47
+ def build_cu_seqlens_for_buckets(bucket_lengths: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Build cumulative sequence lengths (cu_seqlens) for varlen APIs from a vector of lengths.
50
+ Args:
51
+ bucket_lengths: [N] lengths per row in a bucket
52
+ Returns:
53
+ cu_seqlens: [N+1] with cu_seqlens[0]=0 and cu_seqlens[i+1]=sum_{j<=i} len[j]
54
+ """
55
+ if bucket_lengths.numel() == 0:
56
+ return torch.zeros((1,), dtype=torch.int32, device=bucket_lengths.device)
57
+ cs = torch.zeros((bucket_lengths.numel() + 1,), dtype=torch.int32, device=bucket_lengths.device)
58
+ cs[1:] = torch.cumsum(bucket_lengths.to(dtype=torch.int32), dim=0)
59
+ return cs
60
+
61
+
62
+ def pack_batch_by_lengths(
63
+ x: torch.Tensor, lengths: torch.Tensor
64
+ ) -> tuple[torch.Tensor, torch.Tensor]:
65
+ """
66
+ Pack a batch of padded rows into a contiguous buffer with cu_seqlens.
67
+
68
+ Args:
69
+ x: [B,S_max,D]
70
+ lengths: [B] valid lengths per row
71
+ Returns:
72
+ (packed: [sum(lengths), D], cu_seqlens: [B+1])
73
+ """
74
+ device = x.device
75
+ B, S_max, D = x.shape
76
+ assert lengths.shape[0] == B
77
+ cu = build_cu_seqlens_for_buckets(lengths.to(torch.int32))
78
+ N = int(cu[-1].item())
79
+ packed = torch.empty((N, D), dtype=x.dtype, device=device)
80
+ write = 0
81
+ for b in range(B):
82
+ L = int(lengths[b].item())
83
+ if L > 0:
84
+ packed[write : write + L] = x[b, :L]
85
+ write += L
86
+ return packed, cu
87
+
88
+
89
+ def unpack_packed_to_padded(
90
+ packed: torch.Tensor, cu_seqlens: torch.Tensor, S_max: int
91
+ ) -> tuple[torch.Tensor, torch.Tensor]:
92
+ """
93
+ Unpack a packed buffer back to padded batch and mask.
94
+
95
+ Args:
96
+ packed: [N,D]
97
+ cu_seqlens: [B+1]
98
+ S_max: target padded length
99
+ Returns:
100
+ (padded [B,S_max,D], mask [B,S_max])
101
+ """
102
+ device = packed.device
103
+ B = cu_seqlens.shape[0] - 1
104
+ D = packed.shape[-1]
105
+ padded = torch.zeros((B, S_max, D), dtype=packed.dtype, device=device)
106
+ mask = torch.zeros((B, S_max), dtype=torch.bool, device=device)
107
+ for b in range(B):
108
+ start = int(cu_seqlens[b].item())
109
+ end = int(cu_seqlens[b + 1].item())
110
+ L = end - start
111
+ if L > 0:
112
+ padded[b, :L] = packed[start:end]
113
+ mask[b, :L] = True
114
+ return padded, mask
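A minimal round-trip sketch for the packing helpers above (uses the `nsa.core.packing` path shown in this diff):

```python
import torch
from nsa.core.packing import pack_batch_by_lengths, unpack_packed_to_padded

x = torch.randn(3, 8, 4)                     # [B, S_max, D]
lengths = torch.tensor([8, 5, 0])            # valid tokens per row
packed, cu = pack_batch_by_lengths(x, lengths)
print(packed.shape, cu.tolist())             # torch.Size([13, 4]) [0, 8, 13, 13]
padded, mask = unpack_packed_to_padded(packed, cu, S_max=8)
assert torch.equal(padded[1, :5], x[1, :5]) and int(mask.sum()) == 13
```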
nsa/core/rope.py ADDED
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+
6
+ def build_inv_freq(
7
+ dim: int, base: float = 10000.0, device: torch.device | None = None
8
+ ) -> torch.Tensor:
9
+ assert dim % 2 == 0, "RoPE requires even dimension"
10
+ half = dim // 2
11
+ idx = torch.arange(half, device=device, dtype=torch.float32)
12
+ inv_freq = base ** (-2 * idx / dim)
13
+ return inv_freq # [half]
14
+
15
+
16
+ def apply_rope(
17
+ x: torch.Tensor,
18
+ pos: torch.Tensor,
19
+ base: float = 10000.0,
20
+ *,
21
+ scale: float = 1.0,
22
+ ) -> torch.Tensor:
23
+ """
24
+ Apply rotary position embeddings along the last dimension.
25
+
26
+ x: [..., S, D] tensor with even D
27
+ pos: [S] or [..., S] integer positions
28
+ returns: same shape as x
29
+ """
30
+ D = x.shape[-1]
31
+ assert D % 2 == 0, "RoPE requires even dimension"
32
+ device = x.device
33
+ inv_freq = build_inv_freq(D, base=base, device=device) # [D/2]
34
+ # pos shape broadcasting to [..., S, D/2]
35
+ while pos.dim() < x.dim() - 1:
36
+ pos = pos.unsqueeze(0)
37
+ # Simple NTK/YARN-style extension via position scaling: effective_pos = pos / scale
38
+ if scale <= 0:
39
+ scale = 1.0
40
+ # Compute angles in float32 for accuracy, then cast sin/cos to input dtype to preserve dtype end-to-end
41
+ angles = (pos.to(torch.float32) / float(scale)).unsqueeze(
42
+ -1
43
+ ) * inv_freq # [..., S, D/2] (float32)
44
+ sin = torch.sin(angles).to(dtype=x.dtype)
45
+ cos = torch.cos(angles).to(dtype=x.dtype)
46
+ x_2 = x.view(*x.shape[:-1], D // 2, 2)
47
+ x0, x1 = x_2[..., 0], x_2[..., 1]
48
+ y0 = x0 * cos - x1 * sin
49
+ y1 = x0 * sin + x1 * cos
50
+ y = torch.stack((y0, y1), dim=-1).view_as(x)
51
+ return y
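A small usage sketch for `apply_rope`; norm preservation is a basic property of the per-pair rotation:

```python
import torch
from nsa.core.rope import apply_rope

x = torch.randn(2, 16, 64)        # [..., S, D] with even D
pos = torch.arange(16)
y = apply_rope(x, pos)
assert y.shape == x.shape
assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-4)
```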
nsa/core/selection_scorer.py ADDED
@@ -0,0 +1,759 @@
1
+ from __future__ import annotations
2
+ import os
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from .block_index import BlockMeta
9
+
10
+
11
+ def compute_pcmp(Q: torch.Tensor, K_cmp: torch.Tensor, scale: float) -> torch.Tensor:
12
+ # Q: [G,h,Dk] (implicit B=1) or [B,G,h,Dk]; K_cmp: [B,G,S_cmp,Dk]
13
+ if Q.dim() == 3:
14
+ # Q: [G,h,Dk]; K_cmp: [1,G,S_cmp,Dk] (implicit B=1)
15
+ G, h, Dk = Q.shape
16
+ S_cmp = K_cmp.shape[2]
17
+ q = Q.reshape(G * h, 1, Dk)
18
+ # Expand K over heads without materializing copies
19
+ k = (
20
+ K_cmp[0]
21
+ .unsqueeze(1) # [G,1,S_cmp,Dk]
22
+ .expand(G, h, S_cmp, Dk)
23
+ .reshape(G * h, S_cmp, Dk)
24
+ )
25
+ logits = torch.bmm(q, k.transpose(1, 2)).squeeze(1) * scale
26
+ return F.softmax(logits, dim=-1).reshape(1, G, h, S_cmp)
27
+ else:
28
+ # Q: [B,G,h,Dk]; K_cmp: [B,G,S_cmp,Dk]
29
+ B, G, h, Dk = Q.shape
30
+ S_cmp = K_cmp.shape[2]
31
+ q = Q.reshape(B * G * h, 1, Dk)
32
+ # Expand K over heads without materializing copies
33
+ k = (
34
+ K_cmp.unsqueeze(2) # [B,G,1,S_cmp,Dk]
35
+ .expand(B, G, h, S_cmp, Dk)
36
+ .reshape(B * G * h, S_cmp, Dk)
37
+ )
38
+ logits = torch.bmm(q, k.transpose(1, 2)).squeeze(1) * scale
39
+ p = F.softmax(logits, dim=-1)
40
+ return p.reshape(B, G, h, S_cmp)
41
+
42
+
43
+ def compute_pcmp_all(Q_all: torch.Tensor, K_cmp: torch.Tensor, scale: float) -> torch.Tensor:
44
+ """
45
+ Q_all: [B,S,G,h,Dk], K_cmp: [B,G,S_cmp,Dk] -> p_cmp_all: [B,S,G,h,S_cmp]
46
+ """
47
+ use_mixed = os.getenv("NSA_P_CMP_MIXED", "0").lower() in ("1", "true", "yes", "on")
48
+ if use_mixed and Q_all.device.type == "cuda":
49
+ # Optional mixed-precision path (disabled by default). Computes logits and softmax
50
+ # under autocast to reduce memory bandwidth on large shapes. Output is upcast
51
+ # back to the original dtype to preserve downstream numerics.
52
+ orig_dtype = Q_all.dtype
53
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
54
+ Kt = K_cmp.permute(0, 1, 3, 2) # [B,G,Dk,S_cmp]
55
+ logits = torch.einsum("bsghd,bgdc->bsghc", Q_all, Kt) * scale
56
+ p = F.softmax(logits, dim=-1)
57
+ return p.to(orig_dtype)
58
+ else:
59
+ # Baseline precise path
60
+ Kt = K_cmp.permute(0, 1, 3, 2) # [B,G,Dk,S_cmp]
61
+ logits = torch.einsum("bsghd,bgdc->bsghc", Q_all, Kt) * scale
62
+ return F.softmax(logits, dim=-1)
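A shape/normalization sanity check for `compute_pcmp_all` (random tensors; the default precise path is used when `NSA_P_CMP_MIXED` is unset):

```python
import torch
from nsa.core.selection_scorer import compute_pcmp_all

B, S, G, h, Dk, S_cmp = 1, 4, 2, 2, 64, 3
Q_all = torch.randn(B, S, G, h, Dk)
K_cmp = torch.randn(B, G, S_cmp, Dk)
p = compute_pcmp_all(Q_all, K_cmp, scale=Dk ** -0.5)
assert p.shape == (B, S, G, h, S_cmp)
assert torch.allclose(p.sum(dim=-1), torch.ones(B, S, G, h), atol=1e-5)
```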
63
+
64
+
65
+ def map_pcmp_to_pslc(p_cmp: torch.Tensor, meta: BlockMeta) -> torch.Tensor:
66
+ # p_cmp: [B,G,h,S_cmp]
67
+ B, G, h, S_cmp = p_cmp.shape
68
+ indptr = meta.M_csl_indptr
69
+ indices = meta.M_csl_indices
70
+ values = meta.M_csl_values
71
+ S_sel = meta.sel_starts.numel()
72
+ device = p_cmp.device
73
+ # Out-of-place accumulation to avoid in-place versioning issues under GC/DDP
74
+ acc = torch.zeros((B, G, h, S_sel), device=device, dtype=p_cmp.dtype)
75
+
76
+ # CSR row-wise multiply-add
77
+ for r in range(S_cmp):
78
+ start, end = int(indptr[r].item()), int(indptr[r + 1].item())
79
+ if start == end:
80
+ continue
81
+ cols = indices[start:end].to(device)
82
+ w = values[start:end].to(device=device, dtype=p_cmp.dtype) # [nnz_r]
83
+ contrib = p_cmp[..., r].unsqueeze(-1) * w # [B,G,h,nnz_r]
84
+ # Ensure Long dtype for scatter_add indices
85
+ idx = cols.view(1, 1, 1, -1).expand(B, G, h, -1).long()
86
+ acc = acc.scatter_add(-1, idx, contrib)
87
+ return acc
88
+
89
+
90
+ def map_pcmp_to_pslc_batched(p_cmp_all: torch.Tensor, meta: BlockMeta) -> torch.Tensor:
91
+ """
92
+ p_cmp_all: [B,S,G,h,S_cmp] -> p_slc_all: [B,S,G,h,S_sel]
93
+ Vectorized over B,S,G,h while looping CSR rows over S_cmp.
94
+ """
95
+ B, S, G, h, S_cmp = p_cmp_all.shape
96
+ device = p_cmp_all.device
97
+ S_sel = meta.sel_starts.numel()
98
+ if S_cmp == 0:
99
+ return torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)
100
+ # COO sparse matmul: for each nnz (r,c,w), add p_cmp[..., r]*w to p_slc[..., c]
101
+ rows, cols = meta.M_csl_coo_indices.to(device)
102
+ w = meta.M_csl_coo_values.to(device=device, dtype=p_cmp_all.dtype)
103
+ # Filter mapping rows to those < current S_cmp to avoid out-of-bounds in early decode
104
+ valid_mask = rows < S_cmp
105
+ if valid_mask.dim() == 0:
106
+ valid_mask = valid_mask.unsqueeze(0)
107
+ rows = rows[valid_mask]
108
+ cols = cols[valid_mask]
109
+ w = w[valid_mask]
110
+ if rows.numel() == 0:
111
+ return torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)
112
+ p_src = p_cmp_all[..., rows] * w # [B,S,G,h,nnz]
113
+ p_slc = torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)
114
+ # Ensure Long dtype for scatter_add indices
115
+ idx = cols.view(1, 1, 1, 1, -1).expand(B, S, G, h, -1).long()
116
+ p_slc = p_slc.scatter_add(-1, idx, p_src)
117
+ return p_slc
118
+
119
+
120
+ def group_reduce_pslc(p_slc: torch.Tensor) -> torch.Tensor:
121
+ # Sum across heads in group (Eq. 10)
122
+ return p_slc.sum(dim=2)
123
+
124
+
125
+ def select_topn_ranges(
126
+ p_grp: torch.Tensor,
127
+ meta: BlockMeta,
128
+ n_top: int,
129
+ t_token: int,
130
+ force_init: bool = True,
131
+ force_local: int = 2,
132
+ _skip_validation: bool = False,
133
+ ) -> torch.Tensor:
134
+ """Select top-n block ranges with deterministic tie-breaking.
135
+
136
+ M8: Enhanced with robust deterministic tie-breaking for training reproducibility.
137
+ Uses scaled epsilon bias to prefer lower indices on score ties, ensuring
138
+ identical selection across runs with the same inputs.
139
+
140
+ Args:
141
+ p_grp: Group probabilities [B,G,S_sel]
142
+ meta: Block metadata with selection ranges
143
+ n_top: Number of top blocks to select
144
+ t_token: Current token position (0-indexed)
145
+ force_init: Whether to force include block 0
146
+ force_local: Number of local blocks to force include
147
+
148
+ Returns:
149
+ Selected ranges [B,G,n_top,2] as [start,end) pairs
150
+ """
151
+ # p_grp: [B,G,S_sel]
152
+ B, G, S_sel = p_grp.shape
153
+ device = p_grp.device
154
+ # Determine candidate blocks ≤ t
155
+ sel_starts = meta.sel_starts.to(device)
156
+ # mask future blocks
157
+ valid = sel_starts + meta.l_sel - 1 <= t_token
158
+ masked = p_grp.masked_fill(~valid.view(1, 1, -1), float("-inf"))
159
+ # force-includes set
160
+ forced_list = []
161
+ if force_init:
162
+ forced_list.append(torch.zeros((B, G), dtype=torch.int64, device=device))
163
+ if force_local > 0:
164
+ last_block = torch.clamp((torch.tensor(t_token, device=device) // meta.l_sel), min=0)
165
+ for i in range(force_local):
166
+ forced_list.append(torch.clamp(last_block - i, min=0).expand(B, G))
167
+ forced_idx = (
168
+ torch.stack(forced_list, dim=-1)
169
+ if forced_list
170
+ else torch.empty((B, G, 0), device=device, dtype=torch.int64)
171
+ )
172
+ # Exclude forced from top-k candidates by setting their scores to -inf
173
+ if forced_idx.numel() > 0:
174
+ forced_mask = torch.zeros_like(masked, dtype=torch.bool)
175
+ forced_mask.scatter_(-1, forced_idx, True)
176
+ masked = masked.masked_fill(forced_mask, float("-inf"))
177
+ # pick remaining to fill up to n_top
178
+ k_rest = torch.clamp(torch.tensor(n_top - forced_idx.shape[-1], device=device), min=0).item()
179
+ if k_rest > 0:
180
+ # M8: Deterministic tie-breaker - prefer lower indices for reproducible selection
181
+ # Use a tiny, fixed bias in float32 space to avoid overwhelming scores in low-precision
182
+ # dtypes (e.g., bf16/FP16). We perform ranking in float32 regardless of input dtype.
183
+ tie_break_scale = torch.tensor(1e-8, device=device, dtype=torch.float32)
184
+ base_idx = torch.arange(S_sel, device=device, dtype=torch.float32).view(1, 1, S_sel)
185
+ composite = masked.to(torch.float32) - (base_idx * tie_break_scale)
186
+ # Ensure deterministic topk with sorted=True for consistent ordering
187
+ k_actual = min(k_rest, S_sel)
188
+ _, top_idx = torch.topk(composite, k=k_actual, dim=-1, largest=True, sorted=True)
189
+
190
+ # M8: Assert tie-breaking worked - check for potential numerical issues
191
+ if torch.is_grad_enabled():
192
+ # Only check during training when gradients are enabled
193
+ with torch.no_grad():
194
+ orig_scores = torch.gather(masked, -1, top_idx).to(torch.float32)
195
+ if orig_scores.numel() > 1:
196
+ # Check if adjacent scores are suspiciously close (potential tie-break failure)
197
+ score_diffs = torch.diff(orig_scores, dim=-1)
198
+ very_close = torch.abs(score_diffs) < (float(tie_break_scale.item()) * 0.1)
199
+ if very_close.any():
200
+ from nsa.core.debug import log
201
+
202
+ log(
203
+ "warn.selection_tiebreak",
204
+ msg="Close scores detected in selection - potential tie-break instability",
205
+ min_diff=float(torch.abs(score_diffs).min().item()),
206
+ tie_break_scale=float(tie_break_scale),
207
+ )
208
+ sel_idx = torch.cat([forced_idx, top_idx], dim=-1)
209
+ else:
210
+ sel_idx = forced_idx
211
+ # sort selected indices ascending for consistent range merging
212
+ sel_idx = torch.sort(sel_idx, dim=-1).values
213
+
214
+ # M8: Optional determinism validation (skip if called from validation itself)
215
+ if not _skip_validation and os.getenv("NSA_VALIDATE_SELECTION_DETERMINISM", "0").lower() in (
216
+ "1",
217
+ "true",
218
+ "yes",
219
+ ):
220
+ validate_selection_determinism(p_grp, meta, n_top, t_token)
221
+ # merge adjacent into contiguous ranges
222
+ ranges = []
223
+ for b in range(B):
224
+ bg = []
225
+ for g in range(G):
226
+ blocks = sel_starts[sel_idx[b, g]] # [k], sorted non-decreasing
227
+ # Deduplicate without extra sort (faster on GPU for small k)
228
+ blocks = torch.unique_consecutive(blocks)
229
+ if blocks.numel() == 0:
230
+ bg.append(torch.zeros((n_top, 2), dtype=torch.int32, device=device))
231
+ continue
232
+ cur_s = int(blocks[0].item())
233
+ cur_e = cur_s + meta.l_sel
234
+ merged: List[Tuple[int, int]] = []
235
+ for x in blocks[1:].tolist():
236
+ if x == cur_e: # adjacent
237
+ cur_e += meta.l_sel
238
+ else:
239
+ merged.append((cur_s, cur_e))
240
+ cur_s, cur_e = x, x + meta.l_sel
241
+ merged.append((cur_s, cur_e))
242
+ # pad/truncate to n_top
243
+ out = torch.zeros((n_top, 2), dtype=torch.int32, device=device)
244
+ for i, (s, e) in enumerate(merged[:n_top]):
245
+ e = min(e, t_token + 1)
246
+ out[i, 0] = s
247
+ out[i, 1] = e
248
+ bg.append(out)
249
+ ranges.append(torch.stack(bg, dim=0))
250
+ return torch.stack(ranges, dim=0) # [B,G,n_top,2]
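A standalone illustration of the index-bias tie-break idea used above (no `BlockMeta` needed; double precision is used here so the tiny bias stays representable):

```python
import torch

scores = torch.tensor([0.3, 0.5, 0.5, 0.1], dtype=torch.float64)   # indices 1 and 2 tie
bias = torch.arange(scores.numel(), dtype=torch.float64) * 1e-8
_, top = torch.topk(scores - bias, k=2, largest=True, sorted=True)
assert top.tolist() == [1, 2]                                       # lower index wins the tie
```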
251
+
252
+
253
+ # ===== Batched selection (prefill fast path) =====
254
+
255
+
256
+ def select_topn_ranges_batched(
257
+ p_grp_all: torch.Tensor, # [B,S,G,S_sel]
258
+ meta: BlockMeta,
259
+ n_top: int,
260
+ S: int,
261
+ force_init: bool = True,
262
+ force_local: int = 2,
263
+ ) -> torch.Tensor: # [B,S,G,n_ranges,2]
264
+ """
265
+ M8: Deterministic batched selection with enhanced tie-breaking:
266
+ - Mask future blocks per position t via block end ≤ t+1
267
+ - Force include block 0 and last k local blocks (dedup)
268
+ - Exclude forced from scored top‑k
269
+ - Robust deterministic tie‑break to lower index on equal scores
270
+ - Convert to merged contiguous [start,end) ranges clamped to ≤ t+1
271
+ - Validation hooks for training reproducibility
272
+ """
273
+ B, S_q, G, S_sel = p_grp_all.shape
274
+ device = p_grp_all.device
275
+
276
+ sel_starts = meta.sel_starts.to(device)
277
+ sel_ends = sel_starts + meta.l_sel
278
+ tpos = torch.arange(S, device=device).view(S, 1)
279
+ valid = sel_ends.view(1, -1) <= (tpos + 1) # [S,S_sel]
280
+ disallowed = ~valid
281
+ masked = p_grp_all.masked_fill(disallowed.view(1, S, 1, S_sel), float("-inf"))
282
+
283
+ # Forced blocks (dedup across 0 and locals)
284
+ forced_list = []
285
+ if force_init:
286
+ forced_list.append(torch.zeros((B, S, G, 1), dtype=torch.long, device=device))
287
+ if force_local > 0:
288
+ tpos1 = torch.arange(S, device=device)
289
+ last_block = (tpos1 // meta.l_sel).clamp_min(0)
290
+ for k in range(force_local):
291
+ idx = (last_block - k).clamp_min(0).view(1, S, 1, 1).expand(B, S, G, 1)
292
+ forced_list.append(idx)
293
+ forced = (
294
+ torch.cat(forced_list, dim=-1)
295
+ if forced_list
296
+ else torch.empty((B, S, G, 0), dtype=torch.long, device=device)
297
+ )
298
+ if forced.numel() > 0:
299
+ # Ensure ascending per trailing dim then drop duplicates consecutively
300
+ forced = torch.sort(forced, dim=-1).values
301
+ forced = torch.unique_consecutive(forced, dim=-1)
302
+
303
+ if forced.numel() > 0:
304
+ forced_mask = torch.zeros_like(masked, dtype=torch.bool)
305
+ forced_mask.scatter_(-1, forced, True)
306
+ masked = masked.masked_fill(forced_mask, float("-inf"))
307
+
308
+ # Deterministic top‑k using composite key with tiny index bias
309
+ k_rest = max(0, n_top - forced.shape[-1])
310
+ if k_rest > 0:
311
+ # M8: Deterministic tie-breaker - prefer lower indices; rank in float32 to avoid
312
+ # overwhelming biases under low-precision dtypes.
313
+ tie_break_scale = torch.tensor(1e-8, device=device, dtype=torch.float32)
314
+ base_idx = (
315
+ torch.arange(S_sel, device=device, dtype=torch.float32)
316
+ .view(1, 1, 1, S_sel)
317
+ .expand(B, S, G, S_sel)
318
+ )
319
+ composite = masked.to(torch.float32) - (base_idx * tie_break_scale)
320
+ # Ensure deterministic topk with explicit sorted=True for batched path
321
+ k_actual = min(k_rest, S_sel)
322
+ _, top_idx = torch.topk(composite, k=k_actual, dim=-1, largest=True, sorted=True)
323
+
324
+ # M8: Optional validation for tie-breaking effectiveness in training
325
+ if torch.is_grad_enabled() and k_actual > 1:
326
+ with torch.no_grad():
327
+ orig_scores = torch.gather(masked, -1, top_idx).to(torch.float32)
328
+ # Check last dimension for potential tie-break issues
329
+ score_diffs = torch.diff(orig_scores, dim=-1)
330
+ very_close = torch.abs(score_diffs) < (float(tie_break_scale.item()) * 0.1)
331
+ if very_close.any():
332
+ from nsa.core.debug import log
333
+
334
+ log(
335
+ "warn.batched_selection_tiebreak",
336
+ msg="Close scores in batched selection - potential instability",
337
+ batch_close_count=int(very_close.sum().item()),
338
+ tie_break_scale=float(tie_break_scale),
339
+ )
340
+ selected = torch.cat([forced, top_idx], dim=-1)
341
+ else:
342
+ selected = forced[..., :n_top]
343
+
344
+ # Keep only valid (≤ t) indices; drop disallowed fill-ins
345
+ valid_full = valid.view(1, S, 1, S_sel).expand(B, S, G, S_sel)
346
+ is_valid_pick = torch.gather(valid_full, -1, selected)
347
+ # Replace invalid with -1 sentinel
348
+ selected = torch.where(is_valid_pick, selected, torch.full_like(selected, -1))
349
+ # Special-case: if requested n_top ≥ number of valid blocks at t, select exactly all valid blocks [0..t]
350
+ num_valid = valid.sum(dim=1) # [S]
351
+ # Build ascending [0..S_sel-1] to pick prefix per t
352
+ all_idx = torch.arange(S_sel, device=device).view(1, 1, 1, S_sel).expand(B, S, G, S_sel)
353
+ pick_mask = all_idx < num_valid.view(1, S, 1, 1)
354
+ if n_top >= S_sel:
355
+ selected = torch.where(pick_mask, all_idx, torch.full_like(all_idx, -1))
356
+ selected = torch.sort(selected, dim=-1).values
357
+ # Env-gated GPU range conversion (v2) to remove Python loops on hot path
358
+ use_v2 = os.getenv("NSA_SEL_RANGES_V2", "1").lower() in ("1", "true", "yes")
359
+ if use_v2:
360
+ ranges = convert_indices_to_ranges_batched_v2(selected, meta, S)
361
+ else:
362
+ ranges = convert_indices_to_ranges_batched(selected, meta, S)
363
+ return ranges
364
+
365
+
366
+ def convert_indices_to_ranges_batched_dispatch(
367
+ indices: torch.Tensor,
368
+ meta: BlockMeta,
369
+ S: int,
370
+ ) -> torch.Tensor:
371
+ """
372
+ Dispatch helper mirroring production behavior: chooses v2 by default unless disabled.
373
+ Exposed for tests and tooling.
374
+ """
375
+ use_v2 = os.getenv("NSA_SEL_RANGES_V2", "1").lower() in ("1", "true", "yes")
376
+ if use_v2:
377
+ return convert_indices_to_ranges_batched_v2(indices, meta, S)
378
+ return convert_indices_to_ranges_batched(indices, meta, S)
379
+
380
+
381
+ def convert_indices_to_ranges_batched(
382
+ indices: torch.Tensor, # [B,S,G,k]
383
+ meta: BlockMeta,
384
+ S: int,
385
+ ) -> torch.Tensor: # [B,S,G,n_max,2]
386
+ B, S_q, G, k = indices.shape
387
+ device = indices.device
388
+ sel_starts = meta.sel_starts.to(device)
389
+
390
+ all_ranges = []
391
+ for b in range(B):
392
+ for t in range(S_q):
393
+ clamp_end = int(t) + 1
394
+ for g in range(G):
395
+ block_ids = [int(x) for x in indices[b, t, g].tolist() if int(x) >= 0]
396
+ spans = []
397
+ last_s, last_e = None, None
398
+ prev = None
399
+ for bid in block_ids:
400
+ # Skip invalid/out-of-range indices defensively
401
+ if bid < 0 or bid >= sel_starts.numel():
402
+ continue
403
+ if prev is not None and bid == prev:
404
+ continue
405
+ prev = bid
406
+ s0 = int(sel_starts[bid].item())
407
+ e0 = min(s0 + meta.l_sel, clamp_end)
408
+ if e0 <= s0:
409
+ continue
410
+ if last_s is None:
411
+ last_s, last_e = s0, e0
412
+ elif s0 == last_e:
413
+ last_e = e0
414
+ else:
415
+ spans.append((last_s, last_e))
416
+ last_s, last_e = s0, e0
417
+ if last_s is not None:
418
+ spans.append((last_s, last_e))
419
+ all_ranges.append(spans)
420
+
421
+ max_ranges = max((len(r) for r in all_ranges), default=0)
422
+ out = torch.zeros((B, S_q, G, max_ranges, 2), dtype=torch.int32, device=device)
423
+ idx = 0
424
+ for b in range(B):
425
+ for t in range(S_q):
426
+ for g in range(G):
427
+ spans = all_ranges[idx]
428
+ for i, (s0, e0) in enumerate(spans):
429
+ out[b, t, g, i, 0] = s0
430
+ out[b, t, g, i, 1] = e0
431
+ idx += 1
432
+ return out
433
+
434
+
435
+ def convert_indices_to_ranges_batched_v2(
436
+ indices: torch.Tensor, # [B,S,G,k], sorted asc, -1 padded
437
+ meta: BlockMeta,
438
+ S: int,
439
+ ) -> torch.Tensor: # [B,S,G,k,2] (padded with zero-length ranges)
440
+ """
441
+ Vectorized GPU range conversion with no Python loops.
442
+ - Treat equal and +1 successive block ids as a single merged run.
443
+ - Map runs to token [start, end) using sel_starts and l_sel.
444
+ - Clamp end to t+1 per row to preserve causality.
445
+ - Output is padded to k runs per row; zero-length ranges are encoded as [0,0].
446
+ """
447
+ # NVTX annotation support
448
+ _nvtx = os.getenv("NSA_NVTX", "0").lower() in ("1", "true", "yes")
449
+ if _nvtx:
450
+ try:
451
+ torch.cuda.nvtx.range_push("nsa.sel.ranges_v2")
452
+ except Exception:
453
+ _nvtx = False
454
+
455
+ device = indices.device
456
+ B, S_q, G, K = indices.shape
457
+ if K == 0:
458
+ return torch.zeros((B, S_q, G, 0, 2), dtype=torch.int32, device=device)
459
+
460
+ # Valid mask and prepared index tensor
461
+ if _nvtx:
462
+ try:
463
+ torch.cuda.nvtx.range_push("v2_run_detection")
464
+ except Exception:
465
+ pass
466
+
467
+ valid = indices.ge(0)
468
+ x = torch.where(valid, indices, torch.full_like(indices, -2)) # sentinel -2
469
+
470
+ # Identify run starts: first valid element or break in adjacency (including dedup collapse)
471
+ x_shift = torch.cat([torch.full_like(x[..., :1], -2), x[..., :-1]], dim=-1)
472
+ prev_valid = x_shift.ge(0)
473
+ diff = x - x_shift
474
+ adjacent_or_dup = (diff.eq(1) | diff.eq(0)) & prev_valid
475
+ run_start = valid & (~adjacent_or_dup | (~prev_valid))
476
+
477
+ if _nvtx:
478
+ try:
479
+ torch.cuda.nvtx.range_pop()
480
+ except Exception:
481
+ pass
482
+
483
+ # Row-local run ids [0..runs_per_row-1], -1 for invalid
484
+ run_id = run_start.to(torch.int32).cumsum(dim=-1) - 1
485
+ run_id = torch.where(valid, run_id, torch.full_like(run_id, -1))
486
+
487
+ # Number of runs per row and flattened row indexing
488
+ runs_per_row = run_start.sum(dim=-1, dtype=torch.int32) # [B,S,G]
489
+ N = B * S_q * G
490
+ runs_per_row_flat = runs_per_row.reshape(N)
491
+
492
+ # Build flattened per-run metadata
493
+ # Flatten last dim for selection
494
+ run_start_flat = run_start.reshape(-1, K)
495
+ x_flat = x.reshape(-1, K)
496
+ run_id_flat = run_id.reshape(-1, K)
497
+
498
+ # Indices (within last dim) where runs start per row
499
+ pos = torch.arange(K, device=device, dtype=torch.int32)
500
+ pos_flat = pos.view(1, K).expand(run_start_flat.shape[0], K)
501
+ start_pos_flat = pos_flat[run_start_flat]
502
+ # Corresponding block ids where runs start
503
+ start_blk_flat = x_flat[run_start_flat].to(torch.int32)
504
+
505
+ # Build unique global run ids by offsetting row-local run ids with row offsets
506
+ run_offsets = torch.cumsum(torch.nn.functional.pad(runs_per_row_flat, (1, 0)), dim=0)[
507
+ :-1
508
+ ] # [N]
509
+ # Row index per element (0..N-1)
510
+ row_ids = torch.arange(N, device=device, dtype=torch.int32)
511
+ row_ids_per_elem = row_ids.view(N, 1).expand(N, K)
512
+ # Global run id per element; -1 for invalid
513
+ global_rid = torch.where(
514
+ run_id_flat.ge(0),
515
+ run_id_flat + run_offsets.view(N, 1),
516
+ torch.full_like(run_id_flat, -1),
517
+ )
518
+ global_rid_valid = global_rid[run_id_flat.ge(0)] # [total_valid_elems]
519
+
520
+ # For each global run, compute max block id in that run (end block)
521
+ if _nvtx:
522
+ try:
523
+ torch.cuda.nvtx.range_push("v2_scatter_reduce")
524
+ except Exception:
525
+ pass
526
+
527
+ total_runs = int(runs_per_row_flat.sum().item())
528
+ if total_runs == 0:
529
+ if _nvtx:
530
+ try:
531
+ torch.cuda.nvtx.range_pop()
532
+ except Exception:
533
+ pass
534
+ return torch.zeros((B, S_q, G, K, 2), dtype=torch.int32, device=device)
535
+ max_blk = torch.full((total_runs,), -2, dtype=torch.int32, device=device)
536
+ # Values to reduce are block ids for valid elements
537
+ blk_vals = x_flat[run_id_flat.ge(0)].to(torch.int32)
538
+ max_blk.scatter_reduce_(
539
+ 0, global_rid_valid.to(torch.int64), blk_vals, reduce="amax", include_self=False
540
+ )
541
+
542
+ if _nvtx:
543
+ try:
544
+ torch.cuda.nvtx.range_pop()
545
+ except Exception:
546
+ pass
547
+
548
+ # Start block ids per run, collected in row order
549
+ start_blk_per_run = start_blk_flat # length == total_runs
550
+
551
+ # Map block ids to token starts/ends (guard invalid/out-of-range)
552
+ sel_starts = meta.sel_starts.to(device=device, dtype=torch.int32)
553
+ S_sel = int(sel_starts.numel())
554
+ l_sel = int(meta.l_sel)
555
+ valid_runs = (
556
+ (start_blk_per_run >= 0)
557
+ & (start_blk_per_run < S_sel)
558
+ & (max_blk >= 0)
559
+ & (max_blk < S_sel)
560
+ )
561
+ # Default zeros; fill only valid runs
562
+ start_tok_flat = torch.zeros_like(start_blk_per_run, dtype=torch.int32, device=device)
563
+ end_tok_flat = torch.zeros_like(max_blk, dtype=torch.int32, device=device)
564
+ if valid_runs.any():
565
+ start_tok_flat[valid_runs] = sel_starts[start_blk_per_run[valid_runs]]
566
+ end_tok_flat[valid_runs] = sel_starts[max_blk[valid_runs]] + l_sel
567
+
568
+ # Clamp end to t+1 per row (only meaningful for valid runs)
569
+ # Row t positions: [S] repeated over B,G
570
+ tpos = torch.arange(S, device=device, dtype=torch.int32)
571
+ t_rows = tpos.view(1, S, 1).expand(B, S, G).reshape(N) # [N]
572
+ # t per run: repeat per row by runs_per_row
573
+ t_per_run = torch.repeat_interleave(t_rows, runs_per_row_flat)
574
+ end_tok_flat = torch.minimum(end_tok_flat, (t_per_run + 1))
575
+
576
+ # Prepare output [B,S,G,K,2], fill zeros then scatter first runs_per_row entries per row
577
+ out = torch.zeros((B, S_q, G, K, 2), dtype=torch.int32, device=device)
578
+ # Positions within row to write (0..K-1): take row-local run_id at run starts
579
+ run_id_at_starts = (run_id.reshape(-1, K))[run_start_flat]
580
+ # Compute base index in flattened out for each run write
581
+ # Build linear indices for advanced indexing
582
+ # Map flat run order back to (row, pos)
583
+ row_of_run = torch.repeat_interleave(row_ids, runs_per_row_flat)
584
+ pos_in_row = run_id_at_starts # 0..runs_per_row[row]-1
585
+ b = (row_of_run // (S_q * G)).to(torch.int64)
586
+ rem = row_of_run % (S_q * G)
587
+ t = (rem // G).to(torch.int64)
588
+ g = (rem % G).to(torch.int64)
589
+ p = pos_in_row.to(torch.int64)
590
+ # Scatter only valid runs
591
+ if valid_runs.any():
592
+ vr = valid_runs.to(torch.bool)
593
+ b_v = b[vr]
594
+ t_v = t[vr]
595
+ g_v = g[vr]
596
+ p_v = p[vr]
597
+ out[b_v, t_v, g_v, p_v, 0] = start_tok_flat[vr].to(torch.int32)
598
+ out[b_v, t_v, g_v, p_v, 1] = end_tok_flat[vr].to(torch.int32)
599
+
600
+ if _nvtx:
601
+ try:
602
+ torch.cuda.nvtx.range_pop()
603
+ except Exception:
604
+ pass
605
+
606
+ return out
607
+
608
+
609
+ def map_pcmp_to_pslc_slow_path(p_cmp_all: torch.Tensor, meta: BlockMeta) -> torch.Tensor:
610
+ """
611
+ M8: Eq.9 slow path verifier - explicit mathematical computation.
612
+
613
+ This function implements the exact mathematical definition by using the
614
+ CSR mapping directly instead of recomputing overlaps. This ensures it
615
+ matches the fast path exactly.
616
+
617
+ Args:
618
+ p_cmp_all: [B,S,G,h,S_cmp] compressed probabilities
619
+ meta: Block metadata with overlap mapping
620
+
621
+ Returns:
622
+ p_slc_all: [B,S,G,h,S_sel] selection probabilities
623
+ """
624
+ B, S, G, h, S_cmp = p_cmp_all.shape
625
+ device = p_cmp_all.device
626
+ S_sel = meta.sel_starts.numel()
627
+
628
+ if S_cmp == 0:
629
+ return torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)
630
+
631
+ # Use CSR mapping directly (same as fast path but with explicit loops)
632
+ p_slc_all = torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)
633
+
634
+ indptr = meta.M_csl_indptr.to(device)
635
+ indices = meta.M_csl_indices.to(device)
636
+ values = meta.M_csl_values.to(device, dtype=p_cmp_all.dtype)
637
+
638
+ # For each compressed block (CSR row)
639
+ for cmp_i in range(min(S_cmp, len(indptr) - 1)):
640
+ start = int(indptr[cmp_i].item())
641
+ end = int(indptr[cmp_i + 1].item())
642
+
643
+ if start == end:
644
+ continue
645
+
646
+ # Get the selection blocks this compressed block contributes to
647
+ sel_cols = indices[start:end]
648
+ weights = values[start:end]
649
+
650
+ # Add weighted contribution to each selection block
651
+ for j, (sel_idx, weight) in enumerate(zip(sel_cols, weights)):
652
+ sel_idx = int(sel_idx.item())
653
+ if sel_idx < S_sel:
654
+ p_slc_all[..., sel_idx] += p_cmp_all[..., cmp_i] * float(weight.item())
655
+
656
+ return p_slc_all
657
+
658
+
659
+ def verify_mapping_equivalence(
660
+ p_cmp_all: torch.Tensor, meta: BlockMeta, rtol: float = 1e-5, atol: float = 1e-8
661
+ ) -> tuple[bool, dict]:
662
+ """
663
+ M8: Verify fast COO path matches slow mathematical path (Eq.9 verification).
664
+
665
+ Args:
666
+ p_cmp_all: Compressed probabilities to test
667
+ meta: Block metadata
668
+ rtol: Relative tolerance for comparison
669
+ atol: Absolute tolerance for comparison
670
+
671
+ Returns:
672
+ (is_equivalent, details): True if paths match, plus diagnostic info
673
+ """
674
+ # Only run verification if explicitly requested via env flag
675
+ if os.getenv("NSA_VERIFY_EQ9_MAPPING", "0").lower() not in ("1", "true", "yes"):
676
+ return True, {"status": "skipped", "reason": "NSA_VERIFY_EQ9_MAPPING not set"}
677
+
678
+ with torch.no_grad():
679
+ # Compute both paths
680
+ fast_result = map_pcmp_to_pslc_batched(p_cmp_all, meta)
681
+ slow_result = map_pcmp_to_pslc_slow_path(p_cmp_all, meta)
682
+
683
+ # Compare results
684
+ is_close = torch.allclose(fast_result, slow_result, rtol=rtol, atol=atol)
685
+
686
+ # Compute diagnostic metrics
687
+ abs_diff = (fast_result - slow_result).abs()
688
+ max_abs_diff = abs_diff.max().item()
689
+ mean_abs_diff = abs_diff.mean().item()
690
+ rel_diff = abs_diff / (slow_result.abs() + atol)
691
+ max_rel_diff = rel_diff.max().item()
692
+
693
+ details = {
694
+ "status": "verified" if is_close else "mismatch",
695
+ "max_abs_diff": max_abs_diff,
696
+ "mean_abs_diff": mean_abs_diff,
697
+ "max_rel_diff": max_rel_diff,
698
+ "shape": list(p_cmp_all.shape),
699
+ "rtol": rtol,
700
+ "atol": atol,
701
+ }
702
+
703
+ if not is_close:
704
+ from nsa.core.debug import log
705
+
706
+ log(
707
+ "error.eq9_mapping_mismatch",
708
+ msg="Fast COO path does not match slow mathematical path",
709
+ **details,
710
+ )
711
+
712
+ return is_close, details
713
+
714
+
715
+ def validate_selection_determinism(
716
+ p_grp: torch.Tensor, meta: BlockMeta, n_top: int, t_token: int, num_trials: int = 5
717
+ ) -> bool:
718
+ """Validate that selection is deterministic by running multiple times.
719
+
720
+ Args:
721
+ p_grp: Group probabilities [B,G,S_sel]
722
+ meta: Block metadata
723
+ n_top: Number of top blocks to select
724
+ t_token: Current token position
725
+ num_trials: Number of trials to test determinism
726
+
727
+ Returns:
728
+ True if all trials produce identical results
729
+ """
730
+ # Only run validation if explicitly requested via env flag
731
+ if os.getenv("NSA_VALIDATE_SELECTION_DETERMINISM", "0").lower() not in ("1", "true", "yes"):
732
+ return True
733
+
734
+ if p_grp.requires_grad:
735
+ # Don't validate during training to avoid affecting gradients
736
+ return True
737
+
738
+ with torch.no_grad():
739
+ results = []
740
+ for trial in range(num_trials):
741
+ ranges = select_topn_ranges(
742
+ p_grp.clone(), meta, n_top, t_token, True, 2, _skip_validation=True
743
+ )
744
+ results.append(ranges.clone())
745
+
746
+ # Check if all results are identical
747
+ for i in range(1, num_trials):
748
+ if not torch.equal(results[0], results[i]):
749
+ from nsa.core.debug import log
750
+
751
+ log(
752
+ "error.selection_nondeterministic",
753
+ msg=f"Selection non-deterministic: trial 0 != trial {i}",
754
+ trial_0_shape=list(results[0].shape),
755
+ trial_i_shape=list(results[i].shape),
756
+ )
757
+ return False
758
+
759
+ return True
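The verification hooks in this module are opt-in via environment variables; a minimal way to enable them (variable names taken from the code above):

```python
import os

os.environ["NSA_VERIFY_EQ9_MAPPING"] = "1"              # cross-check fast vs. slow Eq.9 mapping
os.environ["NSA_VALIDATE_SELECTION_DETERMINISM"] = "1"  # re-run selection and compare results
```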
nsa/data_pipeline.py ADDED
@@ -0,0 +1,199 @@
1
+ #!/usr/bin/env python3
2
+ """Data pipeline utilities for streaming and local datasets.
3
+
4
+ Provides a FineWeb-Edu IterableDataset and simple local JSONL/TXT loaders.
5
+ This module is optional; scripts/train_showcase.py currently uses a simpler
6
+ loader in scripts/datasets. Migrate incrementally as needed.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+
12
+ import json
13
+ import os
14
+ from dataclasses import dataclass
15
+ from typing import Callable, Iterable, Iterator, List, Optional
16
+
17
+ Tokenizer = Callable[[str], List[int]]
18
+
19
+
20
+ @dataclass
21
+ class Shard:
22
+ mod: int = 1
23
+ rem: int = 0
24
+
25
+
26
+ def fineweb_stream_batches(
27
+ encode: Tokenizer,
28
+ seq_len: int,
29
+ batch_size: int,
30
+ shard: Shard = Shard(),
31
+ report_docs: int = 1000,
32
+ ) -> Iterator[List[List[int]]]:
33
+ try:
34
+ from datasets import Features, Value, load_dataset # type: ignore
35
+ except Exception as e:
36
+ raise RuntimeError("datasets package required. Install with: pip install datasets") from e
37
+
38
+ features = Features(
39
+ {
40
+ "text": Value("string"),
41
+ "id": Value("string"),
42
+ "dump": Value("string"),
43
+ "url": Value("string"),
44
+ "file_path": Value("string"),
45
+ "language": Value("string"),
46
+ "language_score": Value("float64"),
47
+ "token_count": Value("int64"),
48
+ "score": Value("float64"),
49
+ "int_score": Value("int64"),
50
+ }
51
+ )
52
+ ds = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True, features=features)
53
+ buf: List[int] = []
54
+ batch: List[List[int]] = []
55
+ seen = 0
56
+ import time as _t
57
+
58
+ t0 = _t.time()
59
+ last = t0
60
+ for ex in ds:
61
+ if seen % shard.mod != shard.rem:
62
+ seen += 1
63
+ continue
64
+ seen += 1
65
+ if report_docs and seen % report_docs == 0:
66
+ dt = _t.time() - last
67
+ print(f"[fwe] seen_docs={seen} dt={dt:.1f}s buf={len(buf)}", flush=True)
68
+ last = _t.time()
69
+ text = ex.get("text") or ""
70
+ if not text:
71
+ continue
72
+ toks = encode(text)
73
+ if not toks:
74
+ continue
75
+ buf.extend(toks)
76
+ while len(buf) >= seq_len:
77
+ seq = buf[:seq_len]
78
+ buf = buf[seq_len:]
79
+ batch.append(seq)
80
+ if len(batch) >= batch_size:
81
+ yield batch[:batch_size]
82
+ batch = batch[batch_size:]
83
+
84
+
85
+ def fineweb_stream_batches_batched(
86
+ encode_batch: Callable[[List[str]], List[List[int]]],
87
+ seq_len: int,
88
+ batch_size: int,
89
+ shard: Shard = Shard(),
90
+ report_docs: int = 1000,
91
+ doc_batch: int = 64,
92
+ ) -> Iterator[List[List[int]]]:
93
+ """Streaming FineWeb‑Edu with batched tokenization and fixed-length packing.
94
+
95
+ - encode_batch: function mapping a list of texts -> list of token id lists
96
+ - Packs contiguous tokens from a rolling buffer into fixed seq_len examples
97
+ - Yields Python lists of shape [batch_size][seq_len]
98
+ """
99
+ try:
100
+ from datasets import load_dataset, Features, Value # type: ignore
101
+ except Exception as e:
102
+ raise RuntimeError("datasets package required. Install with: pip install datasets") from e
103
+
104
+ features = Features(
105
+ {
106
+ "text": Value("string"),
107
+ "id": Value("string"),
108
+ "dump": Value("string"),
109
+ "url": Value("string"),
110
+ "file_path": Value("string"),
111
+ "language": Value("string"),
112
+ "language_score": Value("float64"),
113
+ "token_count": Value("int64"),
114
+ "score": Value("float64"),
115
+ "int_score": Value("int64"),
116
+ }
117
+ )
118
+ ds = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True, features=features)
119
+
120
+ buf: List[int] = []
121
+ batch: List[List[int]] = []
122
+ seen = 0
123
+ acc_texts: List[str] = []
124
+ import time as _t
125
+ last = _t.time()
126
+ for ex in ds:
127
+ if seen % shard.mod != shard.rem:
128
+ seen += 1
129
+ continue
130
+ seen += 1
131
+ if report_docs and seen % report_docs == 0:
132
+ dt = _t.time() - last
133
+ print(f"[fwe] (batched) seen_docs={seen} dt={dt:.1f}s buf={len(buf)} acc_texts={len(acc_texts)}", flush=True)
134
+ last = _t.time()
135
+ text = ex.get("text") or ""
136
+ if not text:
137
+ continue
138
+ acc_texts.append(text)
139
+ if len(acc_texts) < max(1, int(doc_batch)):
140
+ continue
141
+ # Batched tokenize
142
+ try:
143
+ toks_list = encode_batch(acc_texts)
144
+ except Exception:
145
+ # Fallback to per-doc encode if batch path fails
146
+ toks_list = []
147
+ for t in acc_texts:
148
+ try:
149
+ toks_list.append(encode_batch([t])[0])
150
+ except Exception:
151
+ toks_list.append([])
152
+ acc_texts.clear()
153
+ # Fill rolling buffer and output fixed-length sequences
154
+ for toks in toks_list:
155
+ if not toks:
156
+ continue
157
+ buf.extend(toks)
158
+ while len(buf) >= seq_len:
159
+ seq = buf[:seq_len]
160
+ buf = buf[seq_len:]
161
+ batch.append(seq)
162
+ if len(batch) >= batch_size:
163
+ yield batch[:batch_size]
164
+ batch = batch[batch_size:]
165
+
166
+
167
+ def local_jsonl_or_txt_batches(
168
+ path: str,
169
+ encode: Tokenizer,
170
+ seq_len: int,
171
+ batch_size: int,
172
+ ) -> Iterator[List[List[int]]]:
173
+ is_jsonl = path.endswith(".jsonl")
174
+ buf: List[int] = []
175
+ batch: List[List[int]] = []
176
+ with open(path, encoding="utf-8", errors="ignore") as fh:
177
+ for line in fh:
178
+ line = line.strip()
179
+ if not line:
180
+ continue
181
+ text = line
182
+ if is_jsonl:
183
+ try:
184
+ obj = json.loads(line)
185
+ if isinstance(obj, dict) and isinstance(obj.get("text"), str):
186
+ text = obj["text"]
187
+ except Exception:
188
+ pass
189
+ toks = encode(text)
190
+ if not toks:
191
+ continue
192
+ buf.extend(toks)
193
+ while len(buf) >= seq_len:
194
+ seq = buf[:seq_len]
195
+ buf = buf[seq_len:]
196
+ batch.append(seq)
197
+ if len(batch) >= batch_size:
198
+ yield batch[:batch_size]
199
+ batch = batch[batch_size:]
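A hedged usage sketch for the local loader, assuming a plain-text file and a byte-level encoder consistent with this repo's tokenizer (the `byte_encode` helper is illustrative, not part of the commit):

```python
from nsa.data_pipeline import local_jsonl_or_txt_batches

def byte_encode(text: str) -> list:
    return list(text.encode("utf-8"))        # raw UTF-8 byte ids, vocab=256

with open("tiny.txt", "w", encoding="utf-8") as fh:
    fh.write("hello world\n" * 200)

for batch in local_jsonl_or_txt_batches("tiny.txt", byte_encode, seq_len=128, batch_size=4):
    print(len(batch), len(batch[0]))         # 4 sequences of 128 byte ids
    break
```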
nsa/kernels/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from __future__ import annotations
nsa/kernels/flash_wrappers.py ADDED
@@ -0,0 +1,228 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from nsa.core.debug import log
7
+
8
+
9
+ def _env_bool(name: str, default: bool = False) -> bool:
10
+ v = __import__("os").getenv(name, "1" if default else "0").lower()
11
+ return v in ("1", "true", "yes", "on")
12
+
13
+
14
+ def flash_attn_version() -> str | None:
15
+ """Return flash-attn version string if importable, else None."""
16
+ try:
17
+ import flash_attn as _fa # type: ignore
18
+
19
+ return getattr(_fa, "__version__", None)
20
+ except Exception:
21
+ return None
22
+
23
+
24
+ def is_flash_available() -> bool:
25
+ """Return True if flash-attn dense API is importable."""
26
+ try:
27
+ from flash_attn import flash_attn_func # type: ignore
28
+
29
+ _ = flash_attn_func # silence linter
30
+ return True
31
+ except Exception:
32
+ return False
33
+
34
+
35
+ def is_flash_varlen_available() -> bool:
36
+ """Return True if a varlen API is importable (either QKV or KV-packed)."""
37
+ try:
38
+ from flash_attn import flash_attn_varlen_func # type: ignore
39
+
40
+ _ = flash_attn_varlen_func
41
+ return True
42
+ except Exception:
43
+ try:
44
+ from flash_attn import flash_attn_varlen_kvpacked_func # type: ignore
45
+
46
+ _ = flash_attn_varlen_kvpacked_func
47
+ return True
48
+ except Exception:
49
+ return False
50
+
51
+
52
+ def fa2_supported_verbose(
53
+ device: torch.device, dtype: torch.dtype, head_dim: int
54
+ ) -> tuple[bool, str]:
55
+ """
56
+ Conservative capability probe with a reason string for logging.
57
+ We do not hard-fail on dtype, relying on try/except at call sites.
58
+ """
59
+ if device.type != "cuda":
60
+ return False, "device_not_cuda"
61
+ if head_dim % 8 != 0:
62
+ return False, "head_dim_not_multiple_of_8"
63
+ if not (is_flash_varlen_available() or is_flash_available()):
64
+ return False, "flash_attn_not_importable"
65
+ # Optional version floor (best-effort)
66
+ ver = flash_attn_version()
67
+ if ver is None:
68
+ # Unknown version; still allow
69
+ return True, "ok"
70
+ # Allow all known versions; attach for logs
71
+ return True, f"ok_v{ver}"
72
+
73
+
74
+ def fa2_supported(device: torch.device, dtype: torch.dtype, head_dim: int) -> bool:
75
+ ok, _ = fa2_supported_verbose(device, dtype, head_dim)
76
+ return ok
77
+
78
+
79
+ def attention_bgh(
80
+ Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, causal: bool = True
81
+ ) -> torch.Tensor:
82
+ """
83
+ Q: [B,G,h,Dk], K/V: [B,G,S,D*] -> out [B,G,h,Dv]
84
+ Prefer flash-attn if available; fallback to SDPA.
85
+ """
86
+ B, G, h, Dk = Q.shape
87
+ S = K.shape[2]
88
+ # Try FA-2 dense path first
89
+ if is_flash_available():
90
+ try:
91
+ from flash_attn import flash_attn_func # type: ignore
92
+
93
+ # Reshape without materializing copies
94
+ q = Q.transpose(1, 2).reshape(B, G * h, 1, Dk) # [B,G*h,1,Dk]
95
+ k = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B, G * h, S, Dk) # [B,G*h,S,Dk]
96
+ v = (
97
+ V.unsqueeze(2).expand(B, G, h, S, V.shape[-1]).reshape(B, G * h, S, V.shape[-1])
98
+ ) # [B,G*h,S,Dv]
99
+ if _env_bool("NSA_DEBUG_TIMING"):
100
+ log("fa2.bgh.path", path="fa2.dense", B=B, G=G, h=h, S=S, Dk=Dk)
101
+ o = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=causal)
102
+ o = o.reshape(B, G, h, -1)
103
+ if not torch.isfinite(o).all():
104
+ log("warn.flash_bgh_nonfinite", path="fa2.dense")
105
+ return torch.nan_to_num(o, nan=0.0)
106
+ except Exception:
107
+ pass
108
+ # SDPA fallback
109
+ if _env_bool("NSA_DEBUG_TIMING"):
110
+ log("fa2.bgh.path", path="sdpa", B=B, G=G, h=h, S=S, Dk=Dk)
111
+ # Expand heads via view/expand to avoid materializing copies
112
+ q2 = Q.reshape(B * G * h, 1, Dk).contiguous()
113
+ k2 = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B * G * h, S, Dk).contiguous()
114
+ v2 = (
115
+ V.unsqueeze(2)
116
+ .expand(B, G, h, S, V.shape[-1])
117
+ .reshape(B * G * h, S, V.shape[-1])
118
+ .contiguous()
119
+ )
120
+ attn = F.scaled_dot_product_attention(q2, k2, v2, is_causal=causal)
121
+ o = attn.squeeze(1).reshape(B, G, h, -1)
122
+ return torch.nan_to_num(o, nan=0.0)
123
+
124
+
125
+ def attention_fa2_dense_batch(
126
+ q: torch.Tensor,
127
+ k: torch.Tensor,
128
+ v: torch.Tensor,
129
+ *,
130
+ causal: bool,
131
+ ) -> torch.Tensor:
132
+ """
133
+ Best-effort dense FA-2 call for a batch of independent rows.
134
+ Shapes:
135
+ - q: [N, Tq, h, D]
136
+ - k: [N, Tk, h, D]
137
+ - v: [N, Tk, h, Dv]
138
+ Returns: o [N, Tq, h, Dv]
139
+ Falls back to SDPA if flash-attn unavailable.
140
+ """
141
+ # Ensure contiguous tensors for FA-2
142
+ q = q.contiguous()
143
+ k = k.contiguous()
144
+ v = v.contiguous()
145
+ try:
146
+ from flash_attn import flash_attn_func # type: ignore
147
+
148
+ if _env_bool("NSA_DEBUG_TIMING"):
149
+ log(
150
+ "fa2.batch.path",
151
+ path="fa2.dense",
152
+ N=int(q.shape[0]),
153
+ Tq=int(q.shape[1]),
154
+ Tk=int(k.shape[1]),
155
+ )
156
+ return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=causal)
157
+ except Exception:
158
+ # SDPA fallback per row
159
+ N, Tq, h, D = q.shape
160
+ Tk = k.shape[1]
161
+ Dv = v.shape[-1]
162
+ if _env_bool("NSA_DEBUG_TIMING"):
163
+ log("fa2.batch.path", path="sdpa", N=int(N), Tq=int(Tq), Tk=int(Tk))
164
+ q2 = q.reshape(N * h, Tq, D)
165
+ k2 = k.reshape(N * h, Tk, D)
166
+ v2 = v.reshape(N * h, Tk, Dv)
167
+ out = F.scaled_dot_product_attention(q2, k2, v2, is_causal=causal)
168
+ return out.reshape(N, h, Tq, Dv).permute(0, 2, 1, 3).contiguous()
169
+
170
+
171
+ def attention_fa2_varlen(
172
+ q: torch.Tensor,
173
+ k: torch.Tensor,
174
+ v: torch.Tensor,
175
+ cu_seqlens_q: torch.Tensor,
176
+ cu_seqlens_k: torch.Tensor,
177
+ max_seqlen_q: int,
178
+ max_seqlen_k: int,
179
+ *,
180
+ causal: bool,
181
+ ) -> torch.Tensor:
182
+ """
183
+ Best-effort varlen FA-2 call with separate Q/K/V packing.
184
+ Shapes:
185
+ - q: [total_q, h, D], k: [total_k, h, D], v: [total_k, h, Dv]
186
+ - cu_seqlens_*: int32 [N+1]
187
+ Returns: [total_q, h, Dv] packed output.
188
+ Falls back to dense batching by padding per bucket if varlen API unavailable.
189
+ """
190
+ # Ensure contiguous tensors for FA-2
191
+ q = q.contiguous()
192
+ k = k.contiguous()
193
+ v = v.contiguous()
194
+ try:
195
+ from flash_attn import flash_attn_varlen_func # type: ignore
196
+
197
+ return flash_attn_varlen_func(
198
+ q,
199
+ k,
200
+ v,
201
+ cu_seqlens_q,
202
+ cu_seqlens_k,
203
+ max_seqlen_q,
204
+ max_seqlen_k,
205
+ dropout_p=0.0,
206
+ softmax_scale=None,
207
+ causal=causal,
208
+ )
209
+ except Exception:
210
+ # Try KV-packed API variant
211
+ try:
212
+ from flash_attn import flash_attn_varlen_kvpacked_func # type: ignore
213
+
214
+ # Build KV packed as [total_k, 2, h, D]
215
+ kv_packed = torch.stack([k, v], dim=1).contiguous()
216
+ return flash_attn_varlen_kvpacked_func(
217
+ q,
218
+ kv_packed,
219
+ cu_seqlens_q,
220
+ cu_seqlens_k,
221
+ max_seqlen_q,
222
+ max_seqlen_k,
223
+ dropout_p=0.0,
224
+ softmax_scale=None,
225
+ causal=causal,
226
+ )
227
+ except Exception:
228
+ raise NotImplementedError("FA-2 varlen API not available; caller should fallback")
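A minimal smoke test for the wrappers above (a sketch, not part of the repo). It assumes the functions are importable from this module — the commented import path is a placeholder — and that CUDA/flash-attn are absent, so `fa2_supported` reports False and `attention_bgh` exercises the SDPA fallback:

```python
import torch

# Placeholder import path; adjust to wherever this module lives in the package.
# from nsa.kernels.flash_wrappers import attention_bgh, fa2_supported

B, G, h, S, Dk, Dv = 2, 2, 6, 128, 64, 64    # mirrors the 117M layout: 12 heads in 2 KV groups
Q = torch.randn(B, G, h, Dk)                 # one query step per (group, head)
K = torch.randn(B, G, S, Dk)                 # keys shared within each GQA group
V = torch.randn(B, G, S, Dv)

print(fa2_supported(Q.device, Q.dtype, Dk))  # False on CPU -> SDPA fallback is taken
out = attention_bgh(Q, K, V, causal=True)
print(out.shape)                             # torch.Size([2, 2, 6, 64])
```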
nsa/model/__init__.py ADDED
@@ -0,0 +1 @@
+ from __future__ import annotations
nsa/model/llama_block_nsa.py ADDED
@@ -0,0 +1,129 @@
+ from __future__ import annotations
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from nsa.cache.kv_cache import NSA_KV
+ from nsa.core.block_index import build_block_meta
+ from nsa.core.nsa_attention import NSAAttention
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: [B,S,dim]
+         rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
+         return (x * rms) * self.weight
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim: int, hidden_mult: int = 4) -> None:
+         super().__init__()
+         h = hidden_mult * dim
+         self.fc1 = nn.Linear(dim, h, bias=False)
+         self.fc2 = nn.Linear(h, dim, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.fc2(F.silu(self.fc1(x)))
+
+
+ class LlamaBlockNSA(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         n_heads: int,
+         n_kv_groups: int,
+         d_k: int,
+         d_v: int,
+         l: int = 32,
+         d: int = 16,
+         l_sel: int = 64,
+         n_sel: int = 16,
+         w: int = 512,
+     ) -> None:
+         super().__init__()
+         self.norm1 = RMSNorm(dim)
+         self.attn = NSAAttention(
+             dim=dim,
+             n_heads=n_heads,
+             n_kv_groups=n_kv_groups,
+             d_k=d_k,
+             d_v=d_v,
+             l=l,
+             d=d,
+             l_sel=l_sel,
+             n_sel=n_sel,
+             w=w,
+         )
+         self.norm2 = RMSNorm(dim)
+         self.mlp = MLP(dim)
+
+     def _build_empty_kv(self, x: torch.Tensor) -> NSA_KV:
+         B, S, dim = x.shape
+         device = x.device
+         G = self.attn.n_kv_groups
+         Dk = self.attn.d_k
+         Dv = self.attn.d_v
+         zeros_k = torch.zeros((B, G, 0, Dk), device=device, dtype=x.dtype)
+         zeros_v = torch.zeros((B, G, 0, Dv), device=device, dtype=x.dtype)
+         meta = build_block_meta(
+             seq_len=0,
+             l=self.attn.l,
+             d=self.attn.d,
+             l_sel=self.attn.l_sel,
+             n_sel=self.attn.n_sel,
+             w=self.attn.w,
+         )
+         return NSA_KV(
+             K_sel=zeros_k.clone(),
+             V_sel=zeros_v.clone(),
+             K_win=zeros_k.clone(),
+             V_win=zeros_v.clone(),
+             K_cmp_raw_seq=zeros_k.clone(),
+             V_cmp_raw_seq=zeros_v.clone(),
+             K_cmp=zeros_k.clone(),
+             V_cmp=zeros_v.clone(),
+             win_ptr=torch.zeros((B, G), dtype=torch.int64, device=device),
+             cmp_emit_next=torch.zeros((B, G), dtype=torch.int64, device=device),
+             meta=meta,
+             reads_pred=torch.zeros((0,), dtype=torch.int64, device=device),
+             reads_act_total=torch.zeros((0,), dtype=torch.int64, device=device),
+             reads_act_sel=torch.zeros((0,), dtype=torch.int64, device=device),
+             reads_act_cmp=torch.zeros((0,), dtype=torch.int64, device=device),
+             reads_act_win=torch.zeros((0,), dtype=torch.int64, device=device),
+         )
+
+     def forward_attn(self, x: torch.Tensor) -> torch.Tensor:
+         """Attention sub-layer with residual.
+
+         Exposed to allow gradient-checkpoint splits that exclude attention from
+         checkpointing when dynamic routing could cause recompute mismatches.
+         """
+         B, S, dim = x.shape
+         res = x
+         xn = self.norm1(x)
+         kv = self._build_empty_kv(x)
+         out, _kv = self.attn(xn, kv=kv, prefill=True)
+         return res + out
+
+     def forward_mlp(self, x: torch.Tensor) -> torch.Tensor:
+         """MLP sub-layer with residual.
+
+         Can be safely checkpointed independently from attention.
+         """
+         res = x
+         return res + self.mlp(self.norm2(x))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Default monolithic forward preserves prior behavior
+         x = self.forward_attn(x)
+         x = self.forward_mlp(x)
+         return x
+
+
+ class _EmptyKVLike:
+     pass
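A quick shape check for the block above, using the 117M defaults (dim=768, 12 heads, 2 KV groups, d_k=d_v=64). This is only a sketch: it assumes the `nsa` package (NSAAttention, NSA_KV, build_block_meta) is installed, and note that the block rebuilds an empty NSA_KV on every forward, i.e. it always runs in prefill mode with no cache carried across calls:

```python
import torch
from nsa.model.llama_block_nsa import LlamaBlockNSA  # requires the nsa package

block = LlamaBlockNSA(dim=768, n_heads=12, n_kv_groups=2, d_k=64, d_v=64)  # l/d/l_sel/n_sel/w keep defaults
x = torch.randn(1, 128, 768)   # [B, S, dim]
y = block(x)                   # attention + residual, then MLP + residual
assert y.shape == x.shape

# Sub-layers can also be called separately, e.g. to checkpoint only the MLP:
y2 = block.forward_mlp(block.forward_attn(x))
```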
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {}
tokenization_nsa.py ADDED
@@ -0,0 +1,73 @@
+ # Remote code: byte-level tokenizer for NSA
+ from typing import List, Optional, Dict
+ import json
+ from transformers import PreTrainedTokenizer
+
+
+ class NSAByteTokenizer(PreTrainedTokenizer):
+     """A simple byte-level tokenizer with fixed vocab size 256.
+
+     - Encodes UTF-8 bytes of the input string as token ids 0..255.
+     - No special tokens by default; EOS/PAD can be configured via special tokens map.
+     - Decoding uses UTF-8 with replacement for invalid sequences.
+     """
+
+     def __init__(self, **kwargs):
+         # Build a stable 256-entry vocab mapping before base init (base may query the vocab)
+         self._vocab: Dict[str, int] = {f"<{i}>": i for i in range(256)}
+         self._ids_to_tokens: Dict[int, str] = {i: f"<{i}>" for i in range(256)}
+         super().__init__(**kwargs)
+         # Only return input_ids and attention_mask to avoid unused token_type_ids in generation
+         self.model_input_names = ["input_ids", "attention_mask"]
+
+     @property
+     def vocab_size(self) -> int:  # type: ignore[override]
+         return 256
+
+     def get_vocab(self) -> Dict[str, int]:  # type: ignore[override]
+         return dict(self._vocab)
+
+     def _tokenize(self, text: str) -> List[str]:  # type: ignore[override]
+         data = text.encode("utf-8", errors="replace")
+         return [f"<{b}>" for b in data]
+
+     def _convert_token_to_id(self, token: str) -> int:  # type: ignore[override]
+         if token in self._vocab:
+             return self._vocab[token]
+         # Fallback: try to parse the numeric value inside <..>
+         if token.startswith("<") and token.endswith(">"):
+             try:
+                 v = int(token[1:-1])
+                 if 0 <= v < 256:
+                     return v
+             except Exception:
+                 pass
+         return 0
+
+     def _convert_id_to_token(self, index: int) -> str:  # type: ignore[override]
+         return self._ids_to_tokens.get(int(index) % 256, "<0>")
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:  # type: ignore[override]
+         bs = []
+         for t in tokens:
+             if t in self._vocab:
+                 bs.append(self._vocab[t])
+             else:
+                 try:
+                     if t.startswith("<") and t.endswith(">"):
+                         v = int(t[1:-1])
+                         if 0 <= v < 256:
+                             bs.append(v)
+                             continue
+                 except Exception:
+                     pass
+         return bytes(bs).decode("utf-8", errors="replace")
+
+     def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:  # type: ignore[override]
+         if token_ids_1 is None:
+             return token_ids_0
+         return token_ids_0 + token_ids_1
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):  # type: ignore[override]
+         # Nothing to save besides the special tokens map handled by the base class.
+         return ()
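A round-trip sketch for the byte tokenizer, instantiated directly rather than through `AutoTokenizer` (either works once the repo files are present); exact behavior of the base class can vary slightly across transformers versions:

```python
# from tokenization_nsa import NSAByteTokenizer  # or load via AutoTokenizer with trust_remote_code=True

tok = NSAByteTokenizer()
ids = tok("héllo")["input_ids"]
print(ids)              # UTF-8 bytes: [104, 195, 169, 108, 108, 111]
print(tok.decode(ids))  # "héllo"
print(tok.vocab_size)   # 256
```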
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "tokenizer_class": "NSAByteTokenizer",
+   "model_max_length": 2048,
+   "chat_template": "{% for m in messages %}{% if m['role']=='user' %}<|user|>{{ m['content'] }}\n{% elif m['role']=='assistant' %}<|assistant|>{{ m['content'] }}\n{% endif %}{% endfor %}<|assistant|>",
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_nsa.NSAByteTokenizer",
+       null
+     ]
+   }
+ }
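For reference, this is roughly how the `chat_template` above renders (a sketch; `tok` is any tokenizer instance loaded from this repo). Note that `<|user|>` / `<|assistant|>` are not special tokens in the 256-byte vocab, so they are encoded as plain UTF-8 bytes:

```python
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Tell me about NSA"},
]
text = tok.apply_chat_template(messages, tokenize=False)
# -> "<|user|>Hi\n<|assistant|>Hello!\n<|user|>Tell me about NSA\n<|assistant|>"
```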