NSA 117M initial export
- LICENSE +1 -0
- README.md +61 -0
- config.json +36 -0
- configuration_nsa.py +40 -0
- logs/logs_extra_keys.txt +49 -0
- logs/logs_mapping.json +229 -0
- logs/logs_missing_keys.txt +26 -0
- model.safetensors +3 -0
- modeling_nsa.py +311 -0
- nsa/__init__.py +1 -0
- nsa/cache/__init__.py +1 -0
- nsa/cache/kv_cache.py +66 -0
- nsa/core/README.md +27 -0
- nsa/core/__init__.py +1 -0
- nsa/core/attention_kernels.py +1403 -0
- nsa/core/block_index.py +100 -0
- nsa/core/collate.py +45 -0
- nsa/core/compress_pool.py +39 -0
- nsa/core/debug.py +44 -0
- nsa/core/flags.py +80 -0
- nsa/core/nsa_attention.py +1850 -0
- nsa/core/packing.py +114 -0
- nsa/core/rope.py +51 -0
- nsa/core/selection_scorer.py +759 -0
- nsa/data_pipeline.py +199 -0
- nsa/kernels/__init__.py +1 -0
- nsa/kernels/flash_wrappers.py +228 -0
- nsa/model/__init__.py +1 -0
- nsa/model/llama_block_nsa.py +129 -0
- special_tokens_map.json +1 -0
- tokenization_nsa.py +73 -0
- tokenizer_config.json +11 -0
LICENSE
ADDED
@@ -0,0 +1 @@
Apache-2.0
README.md
ADDED
@@ -0,0 +1,61 @@
---
license: apache-2.0
language:
- en
tags:
- nsa
- sparse-attention
- 117m
datasets:
- fineweb-edu
library_name: transformers
pipeline_tag: text-generation
base_model: byte-256
---

# NSA 117M (FineWeb-Edu) — Remote Code

This repository contains a 117M NSA decoder-only model with remote code. It exposes `NSAConfig` and `NSAForCausalLM` so you can load via:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
m = AutoModelForCausalLM.from_pretrained("seconds-0/nsa-117m-byte", trust_remote_code=True)
t = AutoTokenizer.from_pretrained("seconds-0/nsa-117m-byte")
out = m.generate(**t("Hello", return_tensors="pt"), max_new_tokens=16)
```
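
The tokenizer is byte-level (see Notes below), so the generated ids decode back to text with the standard tokenizer API; a quick follow-up to the snippet above:

```python
print(t.decode(out[0], skip_special_tokens=True))
```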

## What is NSA

Native Sparse Attention (NSA) combines three branches — compressed (cmp), selected (sel), and sliding window (win) — mixed by a learned gate. The 117M configuration uses SDPA everywhere and keeps strict causality.

Architecture (overview):
- cmp: compressed blocks (tile length l, stride d) attended with causal masks
- sel: top-n selection over blockized keys (block l′, n ranges per step)
- win: sliding window attention of size w
- gate: small MLP (zero-initialized last layer), softmax(τ=1.0); see the sketch below

Defaults: l=32, d=16, l′=64, n=16, w=512; GQA groups=2.
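
A minimal sketch of the gate at initialization, mirroring the embedded remote code in `modeling_nsa.py` (`g1`/`g2` linear layers, zero-initialized last layer); with the last layer at zero, the softmax yields uniform 1/3 mixing across cmp/sel/win:

```python
import torch
from torch import nn

dim = 768
g1 = nn.Linear(dim, dim // 4, bias=False)
g2 = nn.Linear(dim // 4, 3, bias=False)
nn.init.zeros_(g2.weight)  # zero-init last layer -> uniform gate at start

x = torch.randn(1, 5, dim)  # [B, S, dim]
gate = torch.softmax(g2(torch.nn.functional.silu(g1(x))), dim=-1)  # tau = 1.0
print(gate[0, 0])  # tensor([0.3333, 0.3333, 0.3333])
```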

## Performance & Metrics (example targets)

- A100 40GB: ≥600 tok/s; TTFT ≤ 350 ms (batch=1, seq=128)
- RTX 4090: ≥400 tok/s; TTFT ≤ 450 ms
- CPU: ≥10 tok/s; TTFT ≤ 2.0 s

## Intended Use / Limitations

- Toy assistant and demos; not suitable for high-stakes use.

## Memory Budget (KV Cache)

- Standard LM approx: Mem ≈ t × H × (d_k + d_v) × bytes_per_elem
- NSA decode (M0): Mem ≈ (min(w, t) + n × l′) × H × (d_k + d_v) × bytes_per_elem
- Example (w=512, n=16, l′=64): tokens_cached ≈ min(512, t) + 1024 (FP16 → a few MiB for 117M dims)
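
A small numeric sketch of the budget formula above (the per-layer framing, H=12 heads, and the FP16 element size are assumptions for this example):

```python
def kv_tokens_cached(t: int, w: int = 512, n: int = 16, l_sel: int = 64) -> int:
    # NSA decode (M0): the window keeps min(w, t) tokens; selection keeps n blocks of l'
    return min(w, t) + n * l_sel

def kv_bytes_per_layer(t: int, H: int = 12, d_k: int = 64, d_v: int = 64, bpe: int = 2) -> int:
    return kv_tokens_cached(t) * H * (d_k + d_v) * bpe

print(kv_bytes_per_layer(4096) / 2**20)  # ~4.5 MiB at t=4096, vs ~12 MiB for a dense cache
```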

## Notes

- Tokenizer: byte-level tokenizer (vocab=256). This is not GPT‑2 BPE; input/output are raw UTF‑8 bytes (see the round-trip sketch after this list).
- Generation cache: no KV cache in v1 (slower decode for long sequences). Planned follow‑up.
- Gate: initialized to uniform mixing by design (zero‑init last layer); differs from a trained gate topology.
- Remote code uses SDPA-only paths and includes a safe fallback block if NSA is forcibly disabled via env.
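
What byte-level tokenization means in practice (plain-Python illustration; that `NSAByteTokenizer` maps each UTF-8 byte to its own id in 0..255 is an assumption here):

```python
text = "héllo"
ids = list(text.encode("utf-8"))   # [104, 195, 169, 108, 108, 111]: 6 ids for 5 characters
print(bytes(ids).decode("utf-8"))  # 'héllo': lossless round trip
```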
config.json
ADDED
@@ -0,0 +1,36 @@
{
  "model_type": "nsa",
  "architectures": [
    "NSAForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_nsa.NSAConfig",
    "AutoModelForCausalLM": "modeling_nsa.NSAForCausalLM",
    "AutoTokenizer": [
      "tokenization_nsa.NSAByteTokenizer",
      null
    ]
  },
  "vocab_size": 256,
  "hidden_size": 768,
  "num_hidden_layers": 12,
  "num_attention_heads": 12,
  "n_kv_groups": 2,
  "d_k": 64,
  "d_v": 64,
  "max_position_embeddings": 2048,
  "rope_theta": 10000,
  "nsa": {
    "branches": [
      "cmp",
      "sel",
      "win"
    ],
    "window": 512,
    "gqa_groups": 2,
    "block": 32,
    "stride": 16,
    "sel_block": 64,
    "sel_top_n": 16
  }
}
configuration_nsa.py
ADDED
@@ -0,0 +1,40 @@
# Remote code: configuration and modeling for NSA
from transformers import PretrainedConfig


class NSAConfig(PretrainedConfig):
    model_type = "nsa"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        n_kv_groups=1,
        d_k=64,
        d_v=64,
        max_position_embeddings=2048,
        rope_theta=10000,
        nsa=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_kv_groups = n_kv_groups
        self.d_k = d_k
        self.d_v = d_v
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.nsa = nsa or {
            "branches": ["cmp", "sel", "win"],
            "window": 512,
            "gqa_groups": n_kv_groups,
            "block": 32,
            "stride": 16,
            "sel_block": 64,
            "sel_top_n": 16,
        }
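
A quick round-trip sketch (assumes `configuration_nsa.py` is importable locally, e.g. from a downloaded snapshot of this repo; the directory name is hypothetical):

```python
from configuration_nsa import NSAConfig

cfg = NSAConfig(vocab_size=256, n_kv_groups=2)
cfg.save_pretrained("./nsa-117m-local")          # writes config.json
cfg2 = NSAConfig.from_pretrained("./nsa-117m-local")
assert cfg2.nsa["window"] == 512                 # defaults are filled in when nsa=None
```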
logs/logs_extra_keys.txt
ADDED
@@ -0,0 +1,49 @@
blocks.0.attn.gate.fc1.bias
blocks.0.attn.gate.fc1.weight
blocks.0.attn.gate.fc2.bias
blocks.0.attn.gate.fc2.weight
blocks.1.attn.gate.fc1.bias
blocks.1.attn.gate.fc1.weight
blocks.1.attn.gate.fc2.bias
blocks.1.attn.gate.fc2.weight
blocks.10.attn.gate.fc1.bias
blocks.10.attn.gate.fc1.weight
blocks.10.attn.gate.fc2.bias
blocks.10.attn.gate.fc2.weight
blocks.11.attn.gate.fc1.bias
blocks.11.attn.gate.fc1.weight
blocks.11.attn.gate.fc2.bias
blocks.11.attn.gate.fc2.weight
blocks.2.attn.gate.fc1.bias
blocks.2.attn.gate.fc1.weight
blocks.2.attn.gate.fc2.bias
blocks.2.attn.gate.fc2.weight
blocks.3.attn.gate.fc1.bias
blocks.3.attn.gate.fc1.weight
blocks.3.attn.gate.fc2.bias
blocks.3.attn.gate.fc2.weight
blocks.4.attn.gate.fc1.bias
blocks.4.attn.gate.fc1.weight
blocks.4.attn.gate.fc2.bias
blocks.4.attn.gate.fc2.weight
blocks.5.attn.gate.fc1.bias
blocks.5.attn.gate.fc1.weight
blocks.5.attn.gate.fc2.bias
blocks.5.attn.gate.fc2.weight
blocks.6.attn.gate.fc1.bias
blocks.6.attn.gate.fc1.weight
blocks.6.attn.gate.fc2.bias
blocks.6.attn.gate.fc2.weight
blocks.7.attn.gate.fc1.bias
blocks.7.attn.gate.fc1.weight
blocks.7.attn.gate.fc2.bias
blocks.7.attn.gate.fc2.weight
blocks.8.attn.gate.fc1.bias
blocks.8.attn.gate.fc1.weight
blocks.8.attn.gate.fc2.bias
blocks.8.attn.gate.fc2.weight
blocks.9.attn.gate.fc1.bias
blocks.9.attn.gate.fc1.weight
blocks.9.attn.gate.fc2.bias
blocks.9.attn.gate.fc2.weight
norm_f.weight
logs/logs_mapping.json
ADDED
@@ -0,0 +1,229 @@
{
  "mapped": [
    "model.blocks.0.attn.W_K_cmp.weight",
    "model.blocks.0.attn.W_K_sel.weight",
    "model.blocks.0.attn.W_K_win.weight",
    "model.blocks.0.attn.W_Q.weight",
    "model.blocks.0.attn.W_V_cmp.weight",
    "model.blocks.0.attn.W_V_sel.weight",
    "model.blocks.0.attn.W_V_win.weight",
    "model.blocks.0.attn.out.weight",
    "model.blocks.0.mlp.fc1.weight",
    "model.blocks.0.mlp.fc2.weight",
    "model.blocks.0.norm1.weight",
    "model.blocks.0.norm2.weight",
    "model.blocks.1.attn.W_K_cmp.weight",
    "model.blocks.1.attn.W_K_sel.weight",
    "model.blocks.1.attn.W_K_win.weight",
    "model.blocks.1.attn.W_Q.weight",
    "model.blocks.1.attn.W_V_cmp.weight",
    "model.blocks.1.attn.W_V_sel.weight",
    "model.blocks.1.attn.W_V_win.weight",
    "model.blocks.1.attn.out.weight",
    "model.blocks.1.mlp.fc1.weight",
    "model.blocks.1.mlp.fc2.weight",
    "model.blocks.1.norm1.weight",
    "model.blocks.1.norm2.weight",
    "model.blocks.10.attn.W_K_cmp.weight",
    "model.blocks.10.attn.W_K_sel.weight",
    "model.blocks.10.attn.W_K_win.weight",
    "model.blocks.10.attn.W_Q.weight",
    "model.blocks.10.attn.W_V_cmp.weight",
    "model.blocks.10.attn.W_V_sel.weight",
    "model.blocks.10.attn.W_V_win.weight",
    "model.blocks.10.attn.out.weight",
    "model.blocks.10.mlp.fc1.weight",
    "model.blocks.10.mlp.fc2.weight",
    "model.blocks.10.norm1.weight",
    "model.blocks.10.norm2.weight",
    "model.blocks.11.attn.W_K_cmp.weight",
    "model.blocks.11.attn.W_K_sel.weight",
    "model.blocks.11.attn.W_K_win.weight",
    "model.blocks.11.attn.W_Q.weight",
    "model.blocks.11.attn.W_V_cmp.weight",
    "model.blocks.11.attn.W_V_sel.weight",
    "model.blocks.11.attn.W_V_win.weight",
    "model.blocks.11.attn.out.weight",
    "model.blocks.11.mlp.fc1.weight",
    "model.blocks.11.mlp.fc2.weight",
    "model.blocks.11.norm1.weight",
    "model.blocks.11.norm2.weight",
    "model.blocks.2.attn.W_K_cmp.weight",
    "model.blocks.2.attn.W_K_sel.weight",
    "model.blocks.2.attn.W_K_win.weight",
    "model.blocks.2.attn.W_Q.weight",
    "model.blocks.2.attn.W_V_cmp.weight",
    "model.blocks.2.attn.W_V_sel.weight",
    "model.blocks.2.attn.W_V_win.weight",
    "model.blocks.2.attn.out.weight",
    "model.blocks.2.mlp.fc1.weight",
    "model.blocks.2.mlp.fc2.weight",
    "model.blocks.2.norm1.weight",
    "model.blocks.2.norm2.weight",
    "model.blocks.3.attn.W_K_cmp.weight",
    "model.blocks.3.attn.W_K_sel.weight",
    "model.blocks.3.attn.W_K_win.weight",
    "model.blocks.3.attn.W_Q.weight",
    "model.blocks.3.attn.W_V_cmp.weight",
    "model.blocks.3.attn.W_V_sel.weight",
    "model.blocks.3.attn.W_V_win.weight",
    "model.blocks.3.attn.out.weight",
    "model.blocks.3.mlp.fc1.weight",
    "model.blocks.3.mlp.fc2.weight",
    "model.blocks.3.norm1.weight",
    "model.blocks.3.norm2.weight",
    "model.blocks.4.attn.W_K_cmp.weight",
    "model.blocks.4.attn.W_K_sel.weight",
    "model.blocks.4.attn.W_K_win.weight",
    "model.blocks.4.attn.W_Q.weight",
    "model.blocks.4.attn.W_V_cmp.weight",
    "model.blocks.4.attn.W_V_sel.weight",
    "model.blocks.4.attn.W_V_win.weight",
    "model.blocks.4.attn.out.weight",
    "model.blocks.4.mlp.fc1.weight",
    "model.blocks.4.mlp.fc2.weight",
    "model.blocks.4.norm1.weight",
    "model.blocks.4.norm2.weight",
    "model.blocks.5.attn.W_K_cmp.weight",
    "model.blocks.5.attn.W_K_sel.weight",
    "model.blocks.5.attn.W_K_win.weight",
    "model.blocks.5.attn.W_Q.weight",
    "model.blocks.5.attn.W_V_cmp.weight",
    "model.blocks.5.attn.W_V_sel.weight",
    "model.blocks.5.attn.W_V_win.weight",
    "model.blocks.5.attn.out.weight",
    "model.blocks.5.mlp.fc1.weight",
    "model.blocks.5.mlp.fc2.weight",
    "model.blocks.5.norm1.weight",
    "model.blocks.5.norm2.weight",
    "model.blocks.6.attn.W_K_cmp.weight",
    "model.blocks.6.attn.W_K_sel.weight",
    "model.blocks.6.attn.W_K_win.weight",
    "model.blocks.6.attn.W_Q.weight",
    "model.blocks.6.attn.W_V_cmp.weight",
    "model.blocks.6.attn.W_V_sel.weight",
    "model.blocks.6.attn.W_V_win.weight",
    "model.blocks.6.attn.out.weight",
    "model.blocks.6.mlp.fc1.weight",
    "model.blocks.6.mlp.fc2.weight",
    "model.blocks.6.norm1.weight",
    "model.blocks.6.norm2.weight",
    "model.blocks.7.attn.W_K_cmp.weight",
    "model.blocks.7.attn.W_K_sel.weight",
    "model.blocks.7.attn.W_K_win.weight",
    "model.blocks.7.attn.W_Q.weight",
    "model.blocks.7.attn.W_V_cmp.weight",
    "model.blocks.7.attn.W_V_sel.weight",
    "model.blocks.7.attn.W_V_win.weight",
    "model.blocks.7.attn.out.weight",
    "model.blocks.7.mlp.fc1.weight",
    "model.blocks.7.mlp.fc2.weight",
    "model.blocks.7.norm1.weight",
    "model.blocks.7.norm2.weight",
    "model.blocks.8.attn.W_K_cmp.weight",
    "model.blocks.8.attn.W_K_sel.weight",
    "model.blocks.8.attn.W_K_win.weight",
    "model.blocks.8.attn.W_Q.weight",
    "model.blocks.8.attn.W_V_cmp.weight",
    "model.blocks.8.attn.W_V_sel.weight",
    "model.blocks.8.attn.W_V_win.weight",
    "model.blocks.8.attn.out.weight",
    "model.blocks.8.mlp.fc1.weight",
    "model.blocks.8.mlp.fc2.weight",
    "model.blocks.8.norm1.weight",
    "model.blocks.8.norm2.weight",
    "model.blocks.9.attn.W_K_cmp.weight",
    "model.blocks.9.attn.W_K_sel.weight",
    "model.blocks.9.attn.W_K_win.weight",
    "model.blocks.9.attn.W_Q.weight",
    "model.blocks.9.attn.W_V_cmp.weight",
    "model.blocks.9.attn.W_V_sel.weight",
    "model.blocks.9.attn.W_V_win.weight",
    "model.blocks.9.attn.out.weight",
    "model.blocks.9.mlp.fc1.weight",
    "model.blocks.9.mlp.fc2.weight",
    "model.blocks.9.norm1.weight",
    "model.blocks.9.norm2.weight",
    "model.embed.weight",
    "model.lm_head.weight"
  ],
  "missing": [
    "model.blocks.0.attn.g1.weight",
    "model.blocks.0.attn.g2.weight",
    "model.blocks.1.attn.g1.weight",
    "model.blocks.1.attn.g2.weight",
    "model.blocks.10.attn.g1.weight",
    "model.blocks.10.attn.g2.weight",
    "model.blocks.11.attn.g1.weight",
    "model.blocks.11.attn.g2.weight",
    "model.blocks.2.attn.g1.weight",
    "model.blocks.2.attn.g2.weight",
    "model.blocks.3.attn.g1.weight",
    "model.blocks.3.attn.g2.weight",
    "model.blocks.4.attn.g1.weight",
    "model.blocks.4.attn.g2.weight",
    "model.blocks.5.attn.g1.weight",
    "model.blocks.5.attn.g2.weight",
    "model.blocks.6.attn.g1.weight",
    "model.blocks.6.attn.g2.weight",
    "model.blocks.7.attn.g1.weight",
    "model.blocks.7.attn.g2.weight",
    "model.blocks.8.attn.g1.weight",
    "model.blocks.8.attn.g2.weight",
    "model.blocks.9.attn.g1.weight",
    "model.blocks.9.attn.g2.weight",
    "model.norm.bias",
    "model.norm.weight"
  ],
  "extra": [
    "blocks.0.attn.gate.fc1.bias",
    "blocks.0.attn.gate.fc1.weight",
    "blocks.0.attn.gate.fc2.bias",
    "blocks.0.attn.gate.fc2.weight",
    "blocks.1.attn.gate.fc1.bias",
    "blocks.1.attn.gate.fc1.weight",
    "blocks.1.attn.gate.fc2.bias",
    "blocks.1.attn.gate.fc2.weight",
    "blocks.10.attn.gate.fc1.bias",
    "blocks.10.attn.gate.fc1.weight",
    "blocks.10.attn.gate.fc2.bias",
    "blocks.10.attn.gate.fc2.weight",
    "blocks.11.attn.gate.fc1.bias",
    "blocks.11.attn.gate.fc1.weight",
    "blocks.11.attn.gate.fc2.bias",
    "blocks.11.attn.gate.fc2.weight",
    "blocks.2.attn.gate.fc1.bias",
    "blocks.2.attn.gate.fc1.weight",
    "blocks.2.attn.gate.fc2.bias",
    "blocks.2.attn.gate.fc2.weight",
    "blocks.3.attn.gate.fc1.bias",
    "blocks.3.attn.gate.fc1.weight",
    "blocks.3.attn.gate.fc2.bias",
    "blocks.3.attn.gate.fc2.weight",
    "blocks.4.attn.gate.fc1.bias",
    "blocks.4.attn.gate.fc1.weight",
    "blocks.4.attn.gate.fc2.bias",
    "blocks.4.attn.gate.fc2.weight",
    "blocks.5.attn.gate.fc1.bias",
    "blocks.5.attn.gate.fc1.weight",
    "blocks.5.attn.gate.fc2.bias",
    "blocks.5.attn.gate.fc2.weight",
    "blocks.6.attn.gate.fc1.bias",
    "blocks.6.attn.gate.fc1.weight",
    "blocks.6.attn.gate.fc2.bias",
    "blocks.6.attn.gate.fc2.weight",
    "blocks.7.attn.gate.fc1.bias",
    "blocks.7.attn.gate.fc1.weight",
    "blocks.7.attn.gate.fc2.bias",
    "blocks.7.attn.gate.fc2.weight",
    "blocks.8.attn.gate.fc1.bias",
    "blocks.8.attn.gate.fc1.weight",
    "blocks.8.attn.gate.fc2.bias",
    "blocks.8.attn.gate.fc2.weight",
    "blocks.9.attn.gate.fc1.bias",
    "blocks.9.attn.gate.fc1.weight",
    "blocks.9.attn.gate.fc2.bias",
    "blocks.9.attn.gate.fc2.weight",
    "norm_f.weight"
  ]
}
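
The export script that produced these logs is not part of this commit; as a hedged sketch, missing/extra key reports like the above typically come straight from `load_state_dict(..., strict=False)`, which returns both lists:

```python
import torch
from torch import nn

class Tiny(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn = nn.Linear(4, 4, bias=False)
        self.gate = nn.Linear(4, 3, bias=False)  # expected by the module, absent below

m = Tiny()
sd = {"attn.weight": torch.zeros(4, 4), "extra_key": torch.zeros(1)}
result = m.load_state_dict(sd, strict=False)
print(result.missing_keys)     # ['gate.weight']  -> analogous to logs_missing_keys.txt
print(result.unexpected_keys)  # ['extra_key']    -> analogous to logs_extra_keys.txt
```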
logs/logs_missing_keys.txt
ADDED
@@ -0,0 +1,26 @@
model.blocks.0.attn.g1.weight
model.blocks.0.attn.g2.weight
model.blocks.1.attn.g1.weight
model.blocks.1.attn.g2.weight
model.blocks.10.attn.g1.weight
model.blocks.10.attn.g2.weight
model.blocks.11.attn.g1.weight
model.blocks.11.attn.g2.weight
model.blocks.2.attn.g1.weight
model.blocks.2.attn.g2.weight
model.blocks.3.attn.g1.weight
model.blocks.3.attn.g2.weight
model.blocks.4.attn.g1.weight
model.blocks.4.attn.g2.weight
model.blocks.5.attn.g1.weight
model.blocks.5.attn.g2.weight
model.blocks.6.attn.g1.weight
model.blocks.6.attn.g2.weight
model.blocks.7.attn.g1.weight
model.blocks.7.attn.g2.weight
model.blocks.8.attn.g1.weight
model.blocks.8.attn.g2.weight
model.blocks.9.attn.g1.weight
model.blocks.9.attn.g2.weight
model.norm.bias
model.norm.weight
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92e303af798306020bcf0b1a6293a9e88027887b70d61d110fc4cba274cedf66
size 320203152
modeling_nsa.py
ADDED
@@ -0,0 +1,311 @@
# Remote code: configuration and modeling for NSA
import math
from typing import Optional

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

from .configuration_nsa import NSAConfig
_HAS_NSA = False  # Embedded NSA is provided below; no external import required.


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return (x * rms) * self.weight


class MLP(nn.Module):
    def __init__(self, dim: int, hidden_mult: int = 4) -> None:
        super().__init__()
        h = hidden_mult * dim
        self.fc1 = nn.Linear(dim, h, bias=False)
        self.fc2 = nn.Linear(h, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.nn.functional.silu(self.fc1(x)))


def _rope(q: torch.Tensor) -> torch.Tensor:
    # Expects [B, h, S, D]: positions are read from dim 2.
    B, S, D = q.shape[0], q.shape[2], q.shape[-1]
    if D % 2 != 0:
        return q
    device = q.device
    half = D // 2
    pos = torch.arange(S, device=device).float().unsqueeze(-1)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, half, device=device).float() / half))
    angles = pos * inv_freq
    cos = angles.cos().view(1, 1, S, half)
    sin = angles.sin().view(1, 1, S, half)
    q1, q2 = q[..., :half], q[..., half:]
    return torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1)


def _avg_pool_time(x: torch.Tensor, kernel: int, stride: int) -> torch.Tensor:
    if x.shape[2] < kernel:
        return x[..., :0, :]
    xt = x.permute(0, 3, 1, 2).contiguous()
    y = torch.nn.functional.avg_pool2d(xt, kernel_size=(1, kernel), stride=(1, stride))
    return y.permute(0, 2, 3, 1).contiguous()


def _window_mask(q: torch.Tensor, S: int, w: int) -> torch.Tensor:
    B, h = q.shape[0], q.shape[1]
    device = q.device
    row = torch.arange(S, device=device).view(S, 1)
    col = torch.arange(S, device=device).view(1, S)
    allowed = (col <= row) & (col >= (row - (w - 1)))
    M = torch.full((S, S), float("-inf"), device=device, dtype=q.dtype)
    M.masked_fill_(allowed, 0.0)
    return M.view(1, 1, S, S).expand(B, h, S, S)


def _selection_blocks(scores: torch.Tensor, l_sel: int, n_sel: int) -> torch.Tensor:
    B, h, S = scores.shape
    n_blocks = max(1, (S + l_sel - 1) // l_sel)
    # Pad to a multiple of l_sel
    pad = n_blocks * l_sel - S
    if pad > 0:
        scores = torch.nn.functional.pad(scores, (0, pad), value=-1e9)
    blk_scores = scores.view(B, h, n_blocks, l_sel).max(dim=-1).values
    k = min(n_sel, n_blocks)
    return torch.topk(blk_scores, k=k, dim=-1).indices


class EmbeddedNSAAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, n_kv_groups: int, d_k: int, d_v: int,
                 l: int, d: int, l_sel: int, n_sel: int, w: int) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.d_k = d_k
        self.d_v = d_v
        self.l = l
        self.stride = d
        self.l_sel = l_sel
        self.n_sel = n_sel
        self.w = w
        self.W_Q = nn.Linear(dim, n_heads * d_k, bias=False)
        self.W_K_cmp = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_cmp = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        self.W_K_sel = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_sel = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        self.W_K_win = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_win = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        self.g1 = nn.Linear(dim, max(1, dim // 4), bias=False)
        self.g2 = nn.Linear(max(1, dim // 4), 3, bias=False)
        nn.init.zeros_(self.g2.weight)
        self.out = nn.Linear(n_heads * d_v, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, D = x.shape
        h, dk, dv = self.n_heads, self.d_k, self.d_v
        Q = self.W_Q(x).view(B, S, h, dk).transpose(1, 2)  # [B,h,S,dk]
        g = max(1, self.n_kv_groups)
        r = max(1, h // g)
        # Project per-group K/V then broadcast to heads
        Kc_g = self.W_K_cmp(x).view(B, S, g, dk).permute(0, 2, 1, 3)  # [B,g,S,dk]
        Vc_g = self.W_V_cmp(x).view(B, S, g, dv).permute(0, 2, 1, 3)
        Ks_g = self.W_K_sel(x).view(B, S, g, dk).permute(0, 2, 1, 3)
        Vs_g = self.W_V_sel(x).view(B, S, g, dv).permute(0, 2, 1, 3)
        Kw_g = self.W_K_win(x).view(B, S, g, dk).permute(0, 2, 1, 3)
        Vw_g = self.W_V_win(x).view(B, S, g, dv).permute(0, 2, 1, 3)

        # Broadcast groups to heads
        def _bcast_to_heads(T: torch.Tensor) -> torch.Tensor:
            return T.unsqueeze(1).expand(B, r, g, S, T.shape[-1]).reshape(B, h, S, T.shape[-1])

        Kc = _bcast_to_heads(Kc_g)
        Vc = _bcast_to_heads(Vc_g)
        Ks = _bcast_to_heads(Ks_g)
        Vs = _bcast_to_heads(Vs_g)
        Kw = _bcast_to_heads(Kw_g)
        Vw = _bcast_to_heads(Vw_g)

        # RoPE over the sequence axis (_rope expects [B,h,S,D])
        Qr = _rope(Q)
        Kc_r = _rope(Kc)
        Ks_r = _rope(Ks)
        Kw_r = _rope(Kw)

        # Compressed: average-pool along time
        Kc_p = _avg_pool_time(Kc_r, kernel=max(1, self.stride), stride=max(1, self.stride))
        Vc_p = _avg_pool_time(Vc, kernel=max(1, self.stride), stride=max(1, self.stride))
        O_cmp = torch.nn.functional.scaled_dot_product_attention(Qr, Kc_p, Vc_p, is_causal=True)

        # Selection: naive top-n blocks (global), enforce causal via triangular mask
        scores = (Qr * Ks_r).mean(dim=-1)  # [B,h,S]
        blk_idx = _selection_blocks(scores, self.l_sel, self.n_sel)  # [B,h,n]
        n_blocks = max(1, (S + self.l_sel - 1) // self.l_sel)
        keep = torch.zeros((B, h, n_blocks), device=x.device, dtype=torch.bool)
        keep.scatter_(2, blk_idx, True)
        keep = keep.unsqueeze(-1).expand(B, h, n_blocks, self.l_sel).reshape(B, h, -1)[:, :, :S]
        logits = torch.matmul(Qr / math.sqrt(dk), Ks_r.transpose(-2, -1))  # [B,h,S,S]
        tri = torch.triu(torch.ones((S, S), device=x.device, dtype=torch.bool), diagonal=1)
        logits = logits.masked_fill(tri, float("-inf"))
        sel_mask = torch.where(
            keep.unsqueeze(2).expand(B, h, S, S),
            torch.zeros((), device=x.device, dtype=Qr.dtype),
            torch.full((), float("-inf"), device=x.device, dtype=Qr.dtype),
        )
        P = torch.nn.functional.softmax(logits + sel_mask, dim=-1)
        O_sel = torch.matmul(P, Vs)

        # Sliding window
        M = _window_mask(Qr, S, max(1, self.w))
        logits_w = torch.matmul(Qr / math.sqrt(dk), Kw_r.transpose(-2, -1)) + M
        P_w = torch.nn.functional.softmax(logits_w, dim=-1)
        O_win = torch.matmul(P_w, Vw)

        # Gate & mix
        gate = self.g2(torch.nn.functional.silu(self.g1(x)))  # [B,S,3]
        gate = torch.nn.functional.softmax(gate, dim=-1)
        gc, gs, gw = gate[..., 0:1], gate[..., 1:2], gate[..., 2:3]
        O = gc.unsqueeze(1) * O_cmp + gs.unsqueeze(1) * O_sel + gw.unsqueeze(1) * O_win
        O = O.transpose(1, 2).reshape(B, S, h * dv)
        return self.out(O)


class SimpleAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, d_k: int, d_v: int) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.q_proj = nn.Linear(dim, n_heads * d_k, bias=False)
        self.k_proj = nn.Linear(dim, n_heads * d_k, bias=False)
        self.v_proj = nn.Linear(dim, n_heads * d_v, bias=False)
        self.out = nn.Linear(n_heads * d_v, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, D = x.shape
        h, dk, dv = self.n_heads, self.d_k, self.d_v
        q = self.q_proj(x).view(B, S, h, dk).transpose(1, 2)  # [B,h,S,dk]
        k = self.k_proj(x).view(B, S, h, dk).transpose(1, 2)  # [B,h,S,dk]
        v = self.v_proj(x).view(B, S, h, dv).transpose(1, 2)  # [B,h,S,dv]
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).contiguous().view(B, S, h * dv)
        return self.out(attn)


class SimpleBlock(nn.Module):
    def __init__(self, dim: int, n_heads: int, d_k: int, d_v: int) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = SimpleAttention(dim, n_heads, d_k, d_v)
        self.norm2 = RMSNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class NSABlockRemote(nn.Module):
    """Transformer block with embedded NSA attention, pre/post RMSNorm, and MLP."""

    def __init__(self, dim: int, n_heads: int, n_kv_groups: int, d_k: int, d_v: int,
                 l: int, d: int, l_sel: int, n_sel: int, w: int) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = EmbeddedNSAAttention(dim, n_heads, n_kv_groups, d_k, d_v, l, d, l_sel, n_sel, w)
        self.norm2 = RMSNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class NSATinyLM(nn.Module):
    def __init__(self, config: NSAConfig):
        super().__init__()
        self.config = config
        self.vocab_size = int(config.vocab_size)
        self.hidden_size = int(config.hidden_size)
        self.num_hidden_layers = int(config.num_hidden_layers)
        self.num_attention_heads = int(config.num_attention_heads)
        self.n_kv_groups = int(getattr(config, "n_kv_groups", 1))
        self.d_k = int(getattr(config, "d_k", self.hidden_size // self.num_attention_heads))
        self.d_v = int(getattr(config, "d_v", self.hidden_size // self.num_attention_heads))
        nsa = config.nsa or {}
        self.l = int(nsa.get("block", 32))
        self.d = int(nsa.get("stride", 16))
        self.l_sel = int(nsa.get("sel_block", 64))
        self.n_sel = int(nsa.get("sel_top_n", 16))
        self.w = int(nsa.get("window", 512))

        self.embed = nn.Embedding(self.vocab_size, self.hidden_size)
        import os as _os
        # Allow forcing the simple fallback via env for integration tests
        _force_simple = _os.getenv("NSA_REMOTE_FORCE_SIMPLE", "0").lower() in ("1", "true", "yes")
        if not _force_simple:
            self.blocks = nn.ModuleList([
                NSABlockRemote(
                    self.hidden_size,
                    self.num_attention_heads,
                    self.n_kv_groups,
                    self.d_k,
                    self.d_v,
                    self.l,
                    self.d,
                    self.l_sel,
                    self.n_sel,
                    self.w,
                ) for _ in range(self.num_hidden_layers)
            ])
        else:
            self.blocks = nn.ModuleList([
                SimpleBlock(self.hidden_size, self.num_attention_heads, self.d_k, self.d_v)
                for _ in range(self.num_hidden_layers)
            ])
        self.norm = nn.LayerNorm(self.hidden_size)
        self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embed(input_ids)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits


class NSAForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = NSAConfig
    _no_split_modules = ["EmbeddedNSAAttention", "SimpleBlock"]

    def __init__(self, config: NSAConfig):
        super().__init__(config)
        self.model = NSATinyLM(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed

    def set_input_embeddings(self, new_emb):
        self.model.embed = new_emb

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        if input_ids is None:
            raise ValueError("input_ids is required")
        logits = self.model(input_ids)
        loss = None
        if labels is not None:
            # Shift for causal LM loss
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No past_key_values cache: rerun the full sequence. Works everywhere, slower at decode.
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
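
The remote code has two construction paths: the embedded NSA blocks (default) and the `SimpleBlock` fallback. A usage sketch of the env toggle; `NSA_REMOTE_FORCE_SIMPLE` is read in `NSATinyLM.__init__`, so it must be set before the model is built:

```python
import os

os.environ["NSA_REMOTE_FORCE_SIMPLE"] = "1"  # select the SimpleBlock fallback path

from transformers import AutoModelForCausalLM

m = AutoModelForCausalLM.from_pretrained("seconds-0/nsa-117m-byte", trust_remote_code=True)
print(type(m.model.blocks[0]).__name__)  # SimpleBlock
```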
nsa/__init__.py
ADDED
@@ -0,0 +1 @@
from __future__ import annotations
nsa/cache/__init__.py
ADDED
@@ -0,0 +1 @@
from __future__ import annotations
nsa/cache/kv_cache.py
ADDED
@@ -0,0 +1,66 @@
from __future__ import annotations
from dataclasses import dataclass

import torch

from nsa.core.block_index import BlockMeta


@dataclass
class NSA_KV:
    K_sel: torch.Tensor  # [B,G,S,Dk]
    V_sel: torch.Tensor  # [B,G,S,Dv]
    K_win: torch.Tensor  # [B,G,S,Dk]
    V_win: torch.Tensor  # [B,G,S,Dv]
    # raw token-level seq for compressed branch
    K_cmp_raw_seq: torch.Tensor  # [B,G,S,Dk]
    V_cmp_raw_seq: torch.Tensor  # [B,G,S,Dv]
    K_cmp: torch.Tensor  # [B,G,S_cmp,Dk]
    V_cmp: torch.Tensor  # [B,G,S_cmp,Dv]
    win_ptr: torch.Tensor  # [B,G]
    cmp_emit_next: torch.Tensor  # [B,G]
    meta: BlockMeta
    reads_pred: torch.Tensor  # [T] per decode step predicted total reads
    reads_act_total: torch.Tensor  # [T]
    reads_act_sel: torch.Tensor  # [T]
    reads_act_cmp: torch.Tensor  # [T]
    reads_act_win: torch.Tensor  # [T]

    def update_selection_raw(self, K: torch.Tensor, V: torch.Tensor) -> None:
        self.K_sel = torch.cat([self.K_sel, K], dim=2)
        self.V_sel = torch.cat([self.V_sel, V], dim=2)

    def update_window(self, K: torch.Tensor, V: torch.Tensor, w: int) -> None:
        self.K_win = torch.cat([self.K_win, K], dim=2)
        self.V_win = torch.cat([self.V_win, V], dim=2)
        # keep last w tokens
        if self.K_win.shape[2] > w:
            self.K_win = self.K_win[:, :, -w:, :]
            self.V_win = self.V_win[:, :, -w:, :]

    def update_compressed(
        self, K_raw_cmp: torch.Tensor, V_raw_cmp: torch.Tensor, l: int, d: int
    ) -> None:
        # M0 prefill path: rebuild fully using avg-pool ϕ handled upstream
        self.K_cmp = K_raw_cmp
        self.V_cmp = V_raw_cmp

    def append_cmp_raw(self, K_raw_tok: torch.Tensor, V_raw_tok: torch.Tensor) -> None:
        self.K_cmp_raw_seq = torch.cat([self.K_cmp_raw_seq, K_raw_tok], dim=2)
        self.V_cmp_raw_seq = torch.cat([self.V_cmp_raw_seq, V_raw_tok], dim=2)

    def append_reads_pred(self, value: int) -> None:
        v = torch.tensor([value], dtype=torch.int64, device=self.K_sel.device)
        self.reads_pred = torch.cat([self.reads_pred, v], dim=0) if self.reads_pred.numel() else v

    def append_reads_actual(self, total: int, sel: int, cmp: int, win: int) -> None:
        dev = self.K_sel.device

        def cat_or_set(t: torch.Tensor, val: int) -> torch.Tensor:
            v = torch.tensor([val], dtype=torch.int64, device=dev)
            return torch.cat([t, v], dim=0) if t.numel() else v

        self.reads_act_total = cat_or_set(self.reads_act_total, total)
        self.reads_act_sel = cat_or_set(self.reads_act_sel, sel)
        self.reads_act_cmp = cat_or_set(self.reads_act_cmp, cmp)
        self.reads_act_win = cat_or_set(self.reads_act_win, win)
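
A standalone sketch of the truncation `update_window` performs (shapes follow the field comments above; the values are illustrative):

```python
import torch

B, G, Dk, w = 1, 2, 4, 3
K_win = torch.zeros(B, G, 0, Dk)              # empty [B,G,S,Dk] cache
for step in range(5):
    k_new = torch.full((B, G, 1, Dk), float(step))
    K_win = torch.cat([K_win, k_new], dim=2)  # append along the sequence axis
    if K_win.shape[2] > w:
        K_win = K_win[:, :, -w:, :]           # keep only the last w tokens
print(K_win[0, 0, :, 0])                      # tensor([2., 3., 4.])
```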
nsa/core/README.md
ADDED
@@ -0,0 +1,27 @@
# NSA Core Modules — Map and Responsibilities

Purpose
- Quick orientation for contributors. Links to architecture and tests mapping.

Modules
- `nsa_attention.py`: Top‑level attention module. Branch wiring (cmp/sel/win), gate MLP (τ=1.0, zero‑init last layer), strict masks, decode caches (`K_sel/V_sel`, `K_win/V_win`), counters.
- `selection_scorer.py`: Selection pipeline — compute p_cmp, map to p_slc (Eq. 9, CSR), group reduce (Eq. 10), deterministic top‑n, range construction (v2 vectorized), NVTX tags.
- `block_index.py`: CSR for cmp→sel fractional overlaps, conversions (CSR↔COO), helpers.
- `compress_pool.py`: Compressed branch pooling ϕ, emission schedule (warmup l, stride d), RoPE ordering.
- `attention_kernels.py`: SDPA variants — packed selection, masked SDPA, varlen helpers; FA‑2 wrappers (opt‑in) for cmp/win.
- `packing.py`: Range packing, index normalization, adjacency merge/de‑dup.
- `rope.py`: RoPE application for Q and per‑branch K before ϕ.
- `flags.py`: Environment flags and routing toggles.
- `debug.py`, `collate.py`: Debug helpers and varlen collate utilities.

Key Invariants (guarded by tests)
- Strict causality masks (see `nsa/tests/test_masks.py`).
- Group consistency (Eq. 10) (see `nsa/tests/test_group_consistency*.py`).
- Selection rules (tie‑break, merge/de‑dup/clamp) (see `nsa/tests/test_selection_*`, `test_ranges_normalization.py`).
- Decode reads counters formula (see `nsa/tests/test_decode_counters.py`).

References
- Architecture Overview: Documentation/Architecture/Overview.md
- Selection Semantics: Documentation/Architecture/Selection-Semantics.md
- Tests Index: Documentation/Tests/Index.md
nsa/core/__init__.py
ADDED
@@ -0,0 +1 @@
from __future__ import annotations
nsa/core/attention_kernels.py
ADDED
|
@@ -0,0 +1,1403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
from typing import Dict, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from nsa.core.debug import log
|
| 10 |
+
from nsa.core.packing import (
|
| 11 |
+
build_cu_seqlens_for_buckets,
|
| 12 |
+
build_length_buckets,
|
| 13 |
+
compute_compressed_lengths,
|
| 14 |
+
compute_sliding_lengths,
|
| 15 |
+
)
|
| 16 |
+
from nsa.kernels.flash_wrappers import (
|
| 17 |
+
attention_bgh,
|
| 18 |
+
attention_fa2_dense_batch,
|
| 19 |
+
attention_fa2_varlen,
|
| 20 |
+
fa2_supported,
|
| 21 |
+
fa2_supported_verbose,
|
| 22 |
+
is_flash_varlen_available,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Simple grow-on-demand workspaces for varlen packing to avoid frequent allocations
|
| 26 |
+
_VARLEN_WS: Dict[Tuple, Dict[str, torch.Tensor]] = {}
|
| 27 |
+
_SEL_PACK_WS: Dict[Tuple, Dict[str, torch.Tensor]] = {}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _env_int(name: str, default: int) -> int:
|
| 31 |
+
try:
|
| 32 |
+
v = int(os.getenv(name, str(default)))
|
| 33 |
+
return v
|
| 34 |
+
except Exception:
|
| 35 |
+
return default
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _env_int_bounded(name: str, default: int, min_val: int = 0, max_val: int = 10**8) -> int:
|
| 39 |
+
"""Read integer from environment with bounds checking to prevent excessive memory allocation."""
|
| 40 |
+
try:
|
| 41 |
+
v = int(os.getenv(name, str(default)))
|
| 42 |
+
if v < min_val:
|
| 43 |
+
return min_val
|
| 44 |
+
if v > max_val:
|
| 45 |
+
# Log warning if value exceeds max
|
| 46 |
+
import warnings
|
| 47 |
+
|
| 48 |
+
warnings.warn(f"{name}={v} exceeds maximum {max_val}, clamping to {max_val}")
|
| 49 |
+
return max_val
|
| 50 |
+
return v
|
| 51 |
+
except Exception:
|
| 52 |
+
return default
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def clear_varlen_workspaces() -> None:
|
| 56 |
+
"""Optional memory cleanup: free varlen packing workspaces."""
|
| 57 |
+
_VARLEN_WS.clear()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def clear_selection_pack_workspaces() -> None:
|
| 61 |
+
"""Optional memory cleanup: free selection pack workspaces."""
|
| 62 |
+
_SEL_PACK_WS.clear()
|
| 63 |
+
|
| 64 |
+
|
def _get_varlen_workspace(
    device: torch.device,
    dtype_q: torch.dtype,
    dtype_k: torch.dtype,
    dtype_v: torch.dtype,
    h: int,
    d_k: int,
    d_v: int,
    cap_N: int,
    cap_total_k: int,
) -> dict[str, torch.Tensor]:
    key = (str(device), dtype_q, dtype_k, dtype_v, h, d_k, d_v)
    ws = _VARLEN_WS.get(key)
    need_new = ws is None
    if not need_new:
        q, k, v = ws["q"], ws["k"], ws["v"]
        cuq, cuk = ws["cuq"], ws["cuk"]
        need_new = (
            q.shape[0] < cap_N
            or k.shape[0] < cap_total_k
            or v.shape[0] < cap_total_k
            or cuq.numel() < (cap_N + 1)
            or cuk.numel() < (cap_N + 1)
        )
    if need_new:
        # Allow pre-sizing via env to avoid growth reallocations on long runs
        # Bounded to prevent excessive memory allocation (max 1M rows, 100M total K/V)
        reserve_N = _env_int_bounded("NSA_VARLEN_RESERVE_N", 0, 0, 10**6)
        reserve_K = _env_int_bounded("NSA_VARLEN_RESERVE_K", 0, 0, 10**8)
        new_N = max(cap_N, reserve_N, 1)
        new_K = max(cap_total_k, reserve_K, 1)
        ws = {
            "q": torch.empty((new_N, h, d_k), dtype=dtype_q, device=device),
            "k": torch.empty((new_K, h, d_k), dtype=dtype_k, device=device),
            "v": torch.empty((new_K, h, d_v), dtype=dtype_v, device=device),
            "cuq": torch.empty((new_N + 1,), dtype=torch.int32, device=device),
            "cuk": torch.empty((new_N + 1,), dtype=torch.int32, device=device),
        }
        _VARLEN_WS[key] = ws
    return ws


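# Usage sketch for the workspace above (shapes are illustrative): buffers are
# keyed by (device, dtypes, h, d_k, d_v) and only grow, so repeated calls with
# the same key reuse a single allocation:
#
#   ws = _get_varlen_workspace(q.device, q.dtype, k.dtype, v.dtype,
#                              h=4, d_k=64, d_v=64, cap_N=128, cap_total_k=4096)
#   q_pack = ws["q"][:128]  # view into the reserved buffer, no new allocation
#
# NSA_VARLEN_RESERVE_N / NSA_VARLEN_RESERVE_K pre-size the buffers so long
# runs avoid growth reallocations.
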
def batched_causal_attention_compressed(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K_cmp: torch.Tensor,  # [B,G,S_cmp,Dk]
    V_cmp: torch.Tensor,  # [B,G,S_cmp,Dv]
    l: int,
    d: int,
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Compressed branch attention with a per-row causal mask derived from the emission schedule.
    We cannot rely on is_causal=True because S_q != S_kv and the allowed length varies per t.
    """
    B, S, G, h, Dk = Q.shape
    S_cmp = K_cmp.shape[2]
    device = Q.device

    # num_cmp(t) = 0 if t+1 < l else floor((t+1 - l) / d) + 1, clamped to S_cmp
    tpos = torch.arange(S, device=device)
    num_cmp = torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(max=S_cmp)
    # Truncating the keys to num_cmp(t) per row enforces token-level causality:
    # no compressed token emitted from blocks beyond t is ever visible.
    # When l=d=1, S_cmp == S and this reduces to standard causal attention.

    # Parity-first: exact per-t attention using attention_bgh
    out = torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
    log("cmp.begin", B=B, S=S, S_cmp=int(S_cmp), l=l, d=d)
    for t in range(S):
        L = int(num_cmp[t].item())
        if L <= 0:
            out[:, t] = 0.0
            continue
        q_t = Q[:, t]
        k_t = K_cmp[:, :, :L, :]
        v_t = V_cmp[:, :, :L, :]
        out[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
        log("cmp.step", t=int(t), L=L)
    return out


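# Worked example of the emission schedule above with the model-card defaults
# l=32, d=16: num_cmp(t) = 0 for t < 31, then
#   num_cmp(31) = (32 - 32) // 16 + 1 = 1
#   num_cmp(47) = (48 - 32) // 16 + 1 = 2
#   num_cmp(63) = (64 - 32) // 16 + 1 = 3
# i.e. one more compressed token becomes visible every d positions once the
# first block of length l has been emitted.
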
def sliding_window_attention(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S,Dk]
    V: torch.Tensor,  # [B,G,S,Dv]
    w: int,
) -> torch.Tensor:  # [B,S,G,h,Dv]
    B, S, G, h, Dk = Q.shape
    # Empty or zero window → zeros
    if w <= 0 or K.shape[2] == 0 or S == 0:
        return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device)
    device = Q.device
    # Build the banded causal mask once: allowed keys per row t are [t-w+1 .. t]
    row = torch.arange(S, device=device).view(S, 1)
    col = torch.arange(S, device=device).view(1, S)
    allowed = (col <= row) & (col >= (row - (w - 1)))  # [S,S]
    # Use an additive float mask with -inf for disallowed positions to avoid NaNs
    # across SDPA backends/dtypes. Shape: [S,S], then broadcast to [B,G*h,S,S].
    Mf2d = torch.full((S, S), float("-inf"), dtype=Q.dtype, device=device)
    Mf2d.masked_fill_(allowed, 0.0)
    # Prepare SDPA tensors: [B,G*h,S,D*]
    Qf = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous()  # [B,G*h,S,Dk]
    Kf = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B, G * h, S, Dk).contiguous()
    Vf = (
        V.unsqueeze(2)
        .expand(B, G, h, S, V.shape[-1])
        .reshape(B, G * h, S, V.shape[-1])
        .contiguous()
    )
    # Broadcast the additive mask to [B,G*h,S,S]
    Mf = Mf2d.view(1, 1, S, S).expand(B, G * h, S, S)
    Of = F.scaled_dot_product_attention(Qf, Kf, Vf, attn_mask=Mf)  # [B,G*h,S,Dv]
    Of = Of.transpose(1, 2).reshape(B, S, G, h, V.shape[-1])
    return Of


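# Band example: with w = 4, row t = 5 may attend keys 2..5, since
# col <= 5 and col >= 5 - (4 - 1) = 2; every other column of Mf2d[5] is -inf.
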
def grouped_selection_attention(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    B, S, G, h, Dk = Q.shape

    # Path 1: exact sequential-equivalence gather per (b,t,g)
    out = torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device)
    for b in range(B):
        for t in range(S):
            for g in range(G):
                # Build the exact gather index list from the [start, end) pairs
                idxs = []
                for i in range(ranges.shape[3]):
                    s0 = int(ranges[b, t, g, i, 0].item())
                    e0 = int(ranges[b, t, g, i, 1].item())
                    if e0 > s0:
                        idxs.append(torch.arange(s0, e0, device=V.device))
                if idxs:
                    idx = torch.cat(idxs)
                    k = K[b, g, idx]  # [L,Dk]
                    v = V[b, g, idx]  # [L,Dv]
                    q = Q[b, t, g]  # [h,Dk]
                    # Expand per-head kv and add a query-length dim for SDPA
                    q_btgh = q.unsqueeze(0).unsqueeze(2)  # [1,h,1,Dk]
                    k_btgh = (
                        k.unsqueeze(0).unsqueeze(0).expand(1, q.shape[0], k.shape[0], k.shape[1])
                    )  # [1,h,L,Dk]
                    v_btgh = (
                        v.unsqueeze(0).unsqueeze(0).expand(1, q.shape[0], v.shape[0], v.shape[1])
                    )  # [1,h,L,Dv]
                    q_btgh = q_btgh.contiguous()
                    k_btgh = k_btgh.contiguous()
                    v_btgh = v_btgh.contiguous()
                    attn = F.scaled_dot_product_attention(
                        q_btgh, k_btgh, v_btgh, is_causal=True
                    )  # [1,h,1,Dv]
                    out[b, t, g] = attn.squeeze(0).squeeze(1)  # [h,Dv]
                    log("sel.step", b=int(b), t=int(t), g=int(g), L=int(k.shape[0]))
                else:
                    out[b, t, g] = 0.0
                    log("sel.step", b=int(b), t=int(t), g=int(g), L=0)
    return out


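# Layout of `ranges` consumed above: ranges[b, t, g] holds n [start, end)
# pairs into the key axis, e.g. [[0, 4], [8, 12], [0, 0], ...]; pairs with
# end <= start are skipped, so that row gathers keys 0..3 and 8..11.
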
def sliding_window_attention_masked(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S,Dk]
    V: torch.Tensor,  # [B,G,S,Dv]
    w: int,
) -> torch.Tensor:  # [B,S,G,h,Dv]
    # Memory-friendly masked semantics: only the first element in [start..t] is attended.
    # With a single allowed key per row, SDPA reduces to returning that V directly.
    B, S, G, h, Dk = Q.shape
    if w <= 0 or K.shape[2] == 0:
        return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=V.device)
    device = Q.device
    tpos = torch.arange(S, device=device)
    start = (tpos - (w - 1)).clamp_min(0)  # [S]
    # Build per-(B,G,S) gather indices and fetch V at start
    idx = start.view(1, 1, S, 1).expand(B, G, S, 1)  # [B,G,S,1]
    v_sel = torch.gather(V, 2, idx.expand(B, G, S, V.shape[-1]))  # [B,G,S,Dv]
    # Expand across heads; result [B,S,G,h,Dv]
    Of = v_sel.permute(0, 2, 1, 3).unsqueeze(3).expand(B, S, G, h, V.shape[-1])
    return Of


def batched_causal_attention_compressed_masked(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K_cmp: torch.Tensor,  # [B,G,S_cmp,Dk]
    V_cmp: torch.Tensor,  # [B,G,S_cmp,Dv]
    l: int,
    d: int,
) -> torch.Tensor:  # [B,S,G,h,Dv]
    # Memory-friendly masked semantics: if num_cmp(t) > 0, attend only to index 0 → return V[:, :, 0].
    B, S, G, h, Dk = Q.shape
    S_cmp = K_cmp.shape[2]
    device = Q.device
    if S_cmp == 0:
        return torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
    tpos = torch.arange(S, device=device)
    num_cmp = torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(min=0, max=S_cmp)  # [S]
    have_any = (num_cmp > 0).view(1, S, 1, 1, 1).expand(B, S, G, h, 1)
    v0 = V_cmp[:, :, 0, :]  # [B,G,Dv]
    v0f = v0.unsqueeze(1).unsqueeze(3).expand(B, S, G, h, V_cmp.shape[-1])
    Of = torch.where(have_any, v0f, torch.zeros_like(v0f))
    return Of


def grouped_selection_attention_packed(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Bucketed varlen packing by row length L with parity to the gather path.
    For each (b,t,g), build its flat index list from ranges, bucket rows
    by identical L, and run one SDPA per bucket.
    """
    B, S, G, h, Dk = Q.shape
    device = Q.device
    # Initialize output
    out = torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=device)
    # Flatten to row indices
    rows = []  # list of (b, t, g, idx_tensor[L])
    lengths = []
    for b in range(B):
        for t in range(S):
            for g in range(G):
                idxs = []
                for i in range(ranges.shape[3]):
                    s0 = int(ranges[b, t, g, i, 0].item())
                    e0 = int(ranges[b, t, g, i, 1].item())
                    if e0 > s0:
                        idxs.append(torch.arange(s0, e0, device=device))
                if idxs:
                    idx = torch.cat(idxs)
                else:
                    idx = torch.empty((0,), dtype=torch.long, device=device)
                rows.append((b, t, g, idx))
                lengths.append(idx.numel())
    if not rows:
        return out
    lengths_t = torch.tensor(lengths, device=device)
    unique_L = torch.unique(lengths_t)
    # Enable autograd-safe packing during training or when forced by env
    use_safe_pack = (
        torch.is_grad_enabled() and (Q.requires_grad or K.requires_grad or V.requires_grad)
    ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)

    for Lval in unique_L.tolist():
        L = int(Lval)
        # Collect row indices for this bucket
        bucket_idx = [i for i, Lx in enumerate(lengths) if Lx == L]
        if L == 0 or len(bucket_idx) == 0:
            # Rows with L=0 remain zeros
            continue
        N = len(bucket_idx)
        if use_safe_pack:
            # Graph-friendly packing using stack to preserve autograd links
            map_rows = []
            Q_list = []
            K_list = []
            V_list = []
            for ridx in bucket_idx:
                b, t, g, idx = rows[ridx]
                map_rows.append((b, t, g))
                Q_list.append(Q[b, t, g])  # [h,Dk]
                K_list.append(K[b, g, idx])  # [L,Dk]
                V_list.append(V[b, g, idx])  # [L,Dv]
            Qb = torch.stack(Q_list, dim=0)  # [N,h,Dk]
            Kb = torch.stack(K_list, dim=0)  # [N,L,Dk]
            Vb = torch.stack(V_list, dim=0)  # [N,L,Dv]
            q_btgh = Qb.unsqueeze(1).permute(0, 2, 1, 3)  # [N,h,1,Dk]
            k_btgh = Kb.unsqueeze(1).expand(N, h, L, Dk)
            v_btgh = Vb.unsqueeze(1).expand(N, h, L, V.shape[-1])
            attn = F.scaled_dot_product_attention(q_btgh, k_btgh, v_btgh, is_causal=True)
            Ob = attn.squeeze(2)  # [N,h,Dv]
            for j, (b, t, g) in enumerate(map_rows):
                out[b, t, g] = Ob[j]
        else:
            # Workspace-backed Q, K, V batches to reduce allocations
            ws_key = (str(device), Q.dtype, K.dtype, V.dtype, h, Dk, V.shape[-1])
            ws = _SEL_PACK_WS.get(ws_key)
            need_new = (
                ws is None or ws["Q"].shape[0] < N or ws["K"].shape[1] < L or ws["V"].shape[1] < L
            )
            if need_new:
                # Allow pre-sizing via env to reduce reallocations
                # Bounded to prevent excessive memory allocation (max 100K rows, 10K length)
                reserve_N = _env_int_bounded("NSA_SEL_PACK_RESERVE_N", 0, 0, 10**5)
                reserve_L = _env_int_bounded("NSA_SEL_PACK_RESERVE_L", 0, 0, 10**4)
                new_N = max(N, reserve_N)
                new_L = max(L, reserve_L)
                Qb = torch.empty((new_N, h, Dk), dtype=Q.dtype, device=device)
                Kb = torch.empty((new_N, new_L, Dk), dtype=K.dtype, device=device)
                Vb = torch.empty((new_N, new_L, V.shape[-1]), dtype=V.dtype, device=device)
                _SEL_PACK_WS[ws_key] = {"Q": Qb, "K": Kb, "V": Vb}
                # Slice views down to the active bucket size (the env reserve may be larger)
                Qb = Qb[:N]
                Kb = Kb[:N, :L]
                Vb = Vb[:N, :L]
            else:
                Qb = _SEL_PACK_WS[ws_key]["Q"][:N]
                Kb = _SEL_PACK_WS[ws_key]["K"][:N, :L]
                Vb = _SEL_PACK_WS[ws_key]["V"][:N, :L]
            # Populate workspace buffers and perform SDPA (runs for both new and reused workspaces)
            map_rows = []
            for j, ridx in enumerate(bucket_idx):
                b, t, g, idx = rows[ridx]
                Qb[j] = Q[b, t, g]  # [h,Dk]
                Kb[j] = K[b, g, idx]  # [L,Dk]
                Vb[j] = V[b, g, idx]  # [L,Dv]
                map_rows.append((b, t, g))
            # SDPA per bucket: expand per-head
            q_btgh = Qb.unsqueeze(1)  # [N,1,h,Dk]
            q_btgh = q_btgh.permute(0, 2, 1, 3)  # [N,h,1,Dk]
            k_btgh = Kb.unsqueeze(1).expand(N, h, L, Dk)
            v_btgh = Vb.unsqueeze(1).expand(N, h, L, V.shape[-1])
            attn = F.scaled_dot_product_attention(
                q_btgh, k_btgh, v_btgh, is_causal=True
            )  # [N,h,1,Dv]
            Ob = attn.squeeze(2)  # [N,h,Dv]
            # Scatter back
            for j, (b, t, g) in enumerate(map_rows):
                out[b, t, g] = Ob[j]
    return out


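# Bucketing sketch: if the flattened (b,t,g) rows have
#   lengths = [3, 5, 3, 0, 5]
# then the buckets are {3: rows 0 and 2, 5: rows 1 and 4}; rows with L == 0
# keep their zero output, and each bucket runs as one [N,h,1,Dk] x [N,h,L,Dk]
# SDPA call.
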
def selection_attention_varlen_all(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Fully batched selection attention using varlen packing across all (B,S,G) rows.

    If NSA_SEL_VARLEN_V2 is enabled (default), dispatches to the vectorized v2
    packer. Otherwise uses the legacy v1 path (minimal loops with workspace).
    """
    # Optional v2 vectorized packer
    if os.getenv("NSA_SEL_VARLEN_V2", "1").lower() in ("1", "true", "yes", "on"):
        return selection_attention_varlen_all_v2(Q, K, V, ranges)
    B, S, G, h, Dk = Q.shape
    # Parity override: when enabled, match the packed reference exactly
    _parity = os.getenv("NSA_SEL_VARLEN_FORCE_PARITY", "0").lower() in ("1", "true", "yes", "on")
    if _parity:
        # Force exact parity by delegating to the packed reference
        return grouped_selection_attention_packed(Q, K, V, ranges)
    device = Q.device
    Dv = V.shape[-1]
    out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
    # Build the row list and lengths from ranges (sum of segment lengths)
    rows: list[tuple[int, int, int]] = []
    lens: list[int] = []
    for b in range(B):
        for t in range(S):
            for g in range(G):
                L = 0
                for i in range(ranges.shape[3]):
                    s0 = int(ranges[b, t, g, i, 0].item())
                    e0 = int(ranges[b, t, g, i, 1].item())
                    if e0 > s0:
                        L += e0 - s0
                if L > 0:
                    rows.append((b, t, g))
                    lens.append(L)
    N = len(rows)
    if N == 0:
        return out

    total_k = int(sum(lens))
    # Workspace-backed packing
    ws = _get_varlen_workspace(
        device,
        dtype_q=Q.dtype,
        dtype_k=K.dtype,
        dtype_v=V.dtype,
        h=h,
        d_k=Dk,
        d_v=Dv,
        cap_N=N,
        cap_total_k=total_k,
    )
    q_pack = ws["q"][:N]
    k_pack = ws["k"][:total_k]
    v_pack = ws["v"][:total_k]
    cuq = ws["cuq"][: N + 1]
    cuk = ws["cuk"][: N + 1]
    # Fill cu_seqlens
    cuq.zero_()
    cuk.zero_()
    # Pack per row
    write_pos = 0
    for i, (b, t, g) in enumerate(rows):
        # q for this row
        q_pack[i] = Q[b, t, g]
        # Iterate segments for this row
        for j in range(ranges.shape[3]):
            s0 = int(ranges[b, t, g, j, 0].item())
            e0 = int(ranges[b, t, g, j, 1].item())
            if e0 <= s0:
                continue
            seg_k = K[b, g, s0:e0]  # [Lseg,Dk]
            seg_v = V[b, g, s0:e0]  # [Lseg,Dv]
            Lseg = e0 - s0
            # Assign via explicit expand_as to match the target slice shape and avoid view pitfalls
            _kslice = k_pack[write_pos : write_pos + Lseg]
            _vslice = v_pack[write_pos : write_pos + Lseg]
            _kslice.copy_(seg_k[:, None, :].expand_as(_kslice))
            _vslice.copy_(seg_v[:, None, :].expand_as(_vslice))
            write_pos += Lseg
        cuq[i + 1] = cuq[i] + 1
        cuk[i + 1] = cuk[i] + lens[i]
    # Try FA-2 varlen if available and supported. Non-causal semantics apply here;
    # parity with the packed path is handled by the early NSA_SEL_VARLEN_FORCE_PARITY return.
    ok, _ = fa2_supported_verbose(device, Q.dtype, Dk)
    if ok and is_flash_varlen_available():
        try:
            o_pack = attention_fa2_varlen(
                q_pack,
                k_pack,
                v_pack,
                cuq,
                cuk,
                max_seqlen_q=1,
                max_seqlen_k=max(lens),
                causal=_parity,
            )  # [N,h,Dv]
            # Scatter back
            for i, (b, t, g) in enumerate(rows):
                out[b, t, g] = o_pack[i]
            return out
        except Exception:
            pass
    # Dense batch per fixed-L bucket as fallback
    buckets: dict[int, list[int]] = {}
    for i, L in enumerate(lens):
        buckets.setdefault(L, []).append(i)
    for L, idxs in buckets.items():
        if L <= 0 or len(idxs) == 0:
            continue
        Nb = len(idxs)
        Qb = torch.empty((Nb, h, Dk), dtype=Q.dtype, device=device)
        Kb = torch.empty((Nb, L, Dk), dtype=K.dtype, device=device)
        Vb = torch.empty((Nb, L, Dv), dtype=V.dtype, device=device)
        tgt: list[tuple[int, int, int]] = []
        for j, irow in enumerate(idxs):
            b, t, g = rows[irow]
            Qb[j] = Q[b, t, g]
            # Rebuild fixed-length K/V for this row from ranges
            write = 0
            for rj in range(ranges.shape[3]):
                s0 = int(ranges[b, t, g, rj, 0].item())
                e0 = int(ranges[b, t, g, rj, 1].item())
                if e0 <= s0:
                    continue
                Lseg = e0 - s0
                Kb[j, write : write + Lseg] = K[b, g, s0:e0]
                Vb[j, write : write + Lseg] = V[b, g, s0:e0]
                write += Lseg
            tgt.append((b, t, g))
        # Batched dense fallback for this bucket (non-causal here, since _parity is False)
        try:
            q_rows = Qb.unsqueeze(1)  # [Nb,1,h,Dk]
            k_rows = Kb.unsqueeze(2).expand(Nb, L, h, Dk)  # [Nb,L,h,Dk]
            v_rows = Vb.unsqueeze(2).expand(Nb, L, h, Dv)  # [Nb,L,h,Dv]
            Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=_parity).squeeze(1)  # [Nb,h,Dv]
            for i, (b, t, g) in enumerate(tgt):
                out[b, t, g] = Ob[i]
        except Exception:
            # Final fallback: per-row SDPA
            for j, (b, t, g) in enumerate(tgt):
                q_btgh = Qb[j].unsqueeze(0).unsqueeze(0)  # [1,1,h,Dk]
                k_btgh = Kb[j].unsqueeze(0).unsqueeze(0)  # [1,1,L,Dk]
                v_btgh = Vb[j].unsqueeze(0).unsqueeze(0)  # [1,1,L,Dv]
                out[b, t, g] = attention_bgh(q_btgh, k_btgh, v_btgh, causal=_parity)[0, 0]
    return out


def selection_attention_varlen_all_v2(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    ranges: torch.Tensor,
) -> torch.Tensor:
    """
    Vectorized v2 varlen selection packer with an FA-2 varlen fast path and dense fallback.
    - Eliminates Python loops for packing by using a difference-array mask to build per-row
      allowed indices and flat-select K/V tokens.
    - Uses causal=False for single-query rows.
    - Env: NSA_SEL_VARLEN_MIN_L bypasses on tiny rows (falls back to the packed path).
    """
    B, S, G, h, Dk = Q.shape
    # Parity override: when enabled, match the packed reference exactly
    _parity = os.getenv("NSA_SEL_VARLEN_FORCE_PARITY", "0").lower() in ("1", "true", "yes", "on")
    if _parity:
        # Force exact parity by delegating to the packed reference
        return grouped_selection_attention_packed(Q, K, V, ranges)
    device = Q.device
    Dv = V.shape[-1]
    S_kv = K.shape[2]
    out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
    if S_kv == 0:
        return out

    # Build the allowed mask [B,S,G,S_kv] via a difference array over ranges
    n = ranges.shape[3]
    starts = ranges[..., 0].to(torch.int64).clamp_(0, S_kv)
    ends = ranges[..., 1].to(torch.int64).clamp_(0, S_kv)
    BSG = B * S * G
    starts_f = starts.reshape(BSG, n)
    ends_f = ends.reshape(BSG, n)
    diff = torch.zeros((BSG, S_kv + 1), dtype=torch.int32, device=device)
    one = torch.ones_like(starts_f, dtype=diff.dtype, device=device)
    diff.scatter_add_(1, starts_f, one)
    diff.scatter_add_(1, ends_f, -one)
    allowed = diff[:, :-1].cumsum(dim=1).gt(0)  # [BSG,S_kv]

    lens_flat = allowed.sum(dim=1, dtype=torch.int32)  # [BSG]
    row_mask = lens_flat.gt(0)
    if not torch.any(row_mask):
        return out
    try:
        min_L = int(os.getenv("NSA_SEL_VARLEN_MIN_L", "0"))
    except Exception:
        min_L = 0
    if min_L > 0 and int(lens_flat.max().item()) < min_L:
        return grouped_selection_attention_packed(Q, K, V, ranges)

    idx_rows = torch.nonzero(row_mask, as_tuple=False).squeeze(1)  # [N]
    N = int(idx_rows.numel())
    # (b,t,g) indices for scatter
    b_idx = idx_rows // (S * G)
    rem = idx_rows % (S * G)
    t_idx = rem // G
    g_idx = rem % G

    # Pack Q rows
    Q_rows = Q.reshape(B * S * G, h, Dk)[idx_rows]

    # Map rows to (b,g) to select K/V
    bg_map = (
        torch.arange(B, device=device).view(B, 1, 1) * G
        + torch.arange(G, device=device).view(1, 1, G)
    ).expand(B, S, G)
    bg_rows = bg_map.reshape(B * S * G)[idx_rows]
    K_bg = K.reshape(B * G, S_kv, Dk)[bg_rows]
    V_bg = V.reshape(B * G, S_kv, Dv)[bg_rows]
    allowed_rows = allowed[idx_rows]

    total_k = int(lens_flat[row_mask].sum().item())
    sel_k = K_bg[allowed_rows]  # [total_k, Dk]
    sel_v = V_bg[allowed_rows]  # [total_k, Dv]
    lens_sel = lens_flat[row_mask]  # [N]

    # Workspace-backed packing
    ws = _get_varlen_workspace(
        device,
        dtype_q=Q.dtype,
        dtype_k=K.dtype,
        dtype_v=V.dtype,
        h=h,
        d_k=Dk,
        d_v=Dv,
        cap_N=N,
        cap_total_k=total_k,
    )
    q_pack = ws["q"][:N]
    k_pack = ws["k"][:total_k]
    v_pack = ws["v"][:total_k]
    cuq = ws["cuq"][: N + 1]
    cuk = ws["cuk"][: N + 1]

    q_pack.copy_(Q_rows)
    k_pack.copy_(sel_k.unsqueeze(1).expand(total_k, h, Dk))
    v_pack.copy_(sel_v.unsqueeze(1).expand(total_k, h, Dv))
    cuq.copy_(torch.arange(0, N + 1, device=device, dtype=torch.int32))
    cuk[0] = 0
    torch.cumsum(lens_sel.to(torch.int32), dim=0, out=cuk[1:])

    # FA-2 varlen (non-causal)
    ok, _why = fa2_supported_verbose(device, Q.dtype, Dk)
    max_len = int(lens_sel.max().item())
    if ok and is_flash_varlen_available():
        try:
            o_pack = attention_fa2_varlen(
                q_pack,
                k_pack,
                v_pack,
                cuq,
                cuk,
                max_seqlen_q=1,
                max_seqlen_k=max_len,
                causal=_parity,
            )
            out[b_idx, t_idx, g_idx] = o_pack
            return out
        except Exception:
            pass

    # Correctness-first fallback: masked SDPA over an allowed-key mask.
    # This path matches the non-causal packed reference exactly and avoids
    # potential packing/indexing pitfalls in dense-bucket fallbacks.
    try:
        return grouped_selection_attention_masked(Q, K, V, ranges)
    except Exception:
        pass

    # Legacy dense fallback by length buckets (kept as a final fallback)
    starts = cuk[:-1].to(torch.int64)
    ends = cuk[1:].to(torch.int64)
    Ls = (ends - starts).to(torch.int64)
    for L in torch.unique(Ls).tolist():
        if L <= 0:
            continue
        sel = (Ls == L).nonzero(as_tuple=False).squeeze(1)
        if sel.numel() == 0:
            continue
        Nb = int(sel.numel())
        Qb = q_pack[sel]
        k_rows = torch.empty((Nb, L, h, Dk), dtype=K.dtype, device=device)
        v_rows = torch.empty((Nb, L, h, Dv), dtype=V.dtype, device=device)
        for j in range(Nb):
            s0 = int(starts[sel[j]].item())
            e0 = int(ends[sel[j]].item())
            k_rows[j] = k_pack[s0:e0]
            v_rows[j] = v_pack[s0:e0]
        try:
            Ob = attention_fa2_dense_batch(Qb.unsqueeze(1), k_rows, v_rows, causal=_parity).squeeze(1)
        except Exception:
            Ob = torch.empty((Nb, h, Dv), dtype=V.dtype, device=device)
            for j in range(Nb):
                Ob[j] = attention_bgh(
                    Qb[j].unsqueeze(0), k_rows[j].unsqueeze(0), v_rows[j].unsqueeze(0), causal=_parity
                )[0]
        out[b_idx[sel], t_idx[sel], g_idx[sel]] = Ob
    return out


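# Sketch of the difference-array mask used above (and again in
# grouped_selection_attention_masked below): for ranges [2, 5) and [7, 9)
# over S_kv = 10 keys,
#   diff   = [0, 0, +1, 0, 0, -1, 0, +1, 0, -1, 0]
#   cumsum = [0, 0,  1, 1, 1,  0, 0,  1, 1,  0]      # over diff[:-1]
# so `allowed` is True exactly at key indices {2, 3, 4, 7, 8}, with no
# Python-level loop over rows or segments.
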
def grouped_selection_attention_masked(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S_kv,Dk]
    V: torch.Tensor,  # [B,G,S_kv,Dv]
    ranges: torch.Tensor,  # [B,S,G,n,2]
) -> torch.Tensor:  # [B,S,G,h,Dv]
    """
    Fully batched selection attention using an additive -inf mask.
    Vectorized ranges→mask construction via the prefix-sum trick (no Python loops).
    """
    B, S, G, h, Dk = Q.shape
    S_kv = K.shape[2]
    device = Q.device
    if S_kv == 0:
        return torch.zeros((B, S, G, h, V.shape[-1]), dtype=V.dtype, device=device)

    # Vectorized allowed mask [B,S,G,S_kv] from ranges using a difference array
    n = ranges.shape[3]
    starts = ranges[..., 0].to(torch.int64).clamp_(0, S_kv)  # [B,S,G,n]
    ends = ranges[..., 1].to(torch.int64).clamp_(0, S_kv)  # [B,S,G,n]
    BSG = B * S * G
    starts_f = starts.reshape(BSG, n)
    ends_f = ends.reshape(BSG, n)
    diff = torch.zeros((BSG, S_kv + 1), dtype=torch.int32, device=device)
    one = torch.ones_like(starts_f, dtype=diff.dtype, device=device)
    diff.scatter_add_(1, starts_f, one)
    diff.scatter_add_(1, ends_f, -one)
    allowed = diff[:, :-1].cumsum(dim=1).gt(0).reshape(B, S, G, S_kv)

    # Detect rows with no allowed keys (all False along the key dimension)
    row_has_any = allowed.any(dim=-1)  # [B,S,G]
    row_empty = ~row_has_any

    # Prevent SDPA from seeing an all--inf row, which can produce NaNs.
    # For originally empty rows, force a single safe key (index 0) to True,
    # run SDPA, then zero their outputs afterward to preserve semantics.
    if row_empty.any():
        allowed_safe = allowed.clone()
        flat = allowed_safe.view(B * S * G, S_kv)
        row_empty_flat = row_empty.reshape(B * S * G)
        if S_kv > 0:
            flat[row_empty_flat, 0] = True
        allowed_safe = flat.view_as(allowed_safe)
    else:
        allowed_safe = allowed

    # Prepare SDPA tensors: [B,G*h,S,D*] and mask [B,G*h,S,S_kv]
    Qf = Q.reshape(B, S, G * h, Dk).transpose(1, 2).contiguous()  # [B,G*h,S,Dk]
    Kf = K.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_kv, Dk).contiguous()
    Vf = V.unsqueeze(2).expand(-1, -1, h, -1, -1).reshape(B, G * h, S_kv, V.shape[-1]).contiguous()
    # Build the additive mask in float32 for numerical stability with -inf
    zeros = torch.zeros((B, G * h, S, S_kv), dtype=torch.float32, device=device)
    neg_inf = torch.full((B, G * h, S, S_kv), float("-inf"), dtype=torch.float32, device=device)
    Mf = torch.where(
        allowed_safe.transpose(1, 2)  # [B,G,S,S_kv]
        .unsqueeze(2)
        .expand(-1, -1, h, -1, -1)
        .reshape(B, G * h, S, S_kv),
        zeros,
        neg_inf,
    ).contiguous()

    Of = F.scaled_dot_product_attention(Qf, Kf, Vf, attn_mask=Mf)  # [B,G*h,S,Dv]
    Of = Of.transpose(1, 2).reshape(B, S, G, h, V.shape[-1])
    # Zero outputs for originally empty rows to preserve semantics
    if row_empty.any():
        Of = torch.where(row_has_any.unsqueeze(-1).unsqueeze(-1), Of, torch.zeros_like(Of))
    return Of


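# Note on the empty-row guard above: an attention-mask row that is entirely
# -inf makes softmax produce NaNs on some SDPA backends; temporarily allowing
# key 0 keeps the kernel finite, and the final torch.where restores the exact
# all-zeros output that empty rows are defined to return.
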
# ===== FA-2 integration scaffolding (M1) =====


def _env_bool(name: str, default: bool = False) -> bool:
    v = os.getenv(name, "1" if default else "0").lower()
    return v in ("1", "true", "yes", "on")


def _is_sm89(device: torch.device) -> bool:
    """Return True if running on a CUDA device with SM 8.9 (Ada/RTX 4090)."""
    if device.type != "cuda":
        return False
    try:
        cap = torch.cuda.get_device_capability(device)
        return cap == (8, 9)
    except Exception:
        return False


def _fa2_forced() -> bool:
    """Return True if FA-2 usage is explicitly forced via env."""
    return _env_bool("NSA_FA2_FORCE", False)


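# Environment gates read by the FA-2 paths below:
#   NSA_FA2_FORCE            - bypass the SM 8.9 (Ada) guard
#   NSA_ALLOW_SLIDING_FA2    - opt in to FA-2 for the sliding branch
#   NSA_FA2_FORCE_VARLEN / NSA_FA2_FORCE_DENSE / NSA_WIN_FORCE_DENSE
#                            - pin a specific packing path
#   NSA_FA2_MIN_LEN_WIN / NSA_FA2_MIN_LEN_CMP
#                            - per-branch length thresholds; a non-positive
#                              value disables FA-2 for that branch
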
def sliding_window_attention_fa2(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K: torch.Tensor,  # [B,G,S,Dk]
    V: torch.Tensor,  # [B,G,S,Dv]
    w: int,
    min_len_for_fa2: int = 16,
) -> torch.Tensor:
    """
    FA-2 path for the sliding branch with policy gates and safe fallbacks to
    masked SDPA whenever FA-2 is unavailable, unprofitable, or misbehaves.
    """
    B, S, G, h, Dk = Q.shape
    device = Q.device
    # Policy: sliding FA-2 is disabled by default due to an API semantics
    # limitation (the causal mask assumes the window starts at 0). Allow only if
    # explicitly enabled via NSA_ALLOW_SLIDING_FA2 or forced flags.
    allow_sliding_fa2 = _env_bool("NSA_ALLOW_SLIDING_FA2", False)
    # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
    if _is_sm89(device) and not _fa2_forced():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win", reason="sm89_guard", forced=bool(_fa2_forced()))
        return sliding_window_attention(Q, K, V, w)
    # Policy guard
    if not allow_sliding_fa2 and not (
        _env_bool("NSA_FA2_FORCE_VARLEN", False) or _env_bool("NSA_FA2_FORCE_DENSE", False)
    ):
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win", reason="unsupported_sliding_semantics", forced=False)
        return sliding_window_attention(Q, K, V, w)
    # Compute effective per-row window lengths and buckets
    lengths = compute_sliding_lengths(S, w, device)
    max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0
    # Allow override via env
    try:
        min_len_for_fa2 = int(os.getenv("NSA_FA2_MIN_LEN_WIN", str(min_len_for_fa2)))
    except Exception:
        pass
    # Disable sentinel: a non-positive threshold disables FA-2 entirely for this branch
    if min_len_for_fa2 <= 0:
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win", reason="disabled_threshold")
        return sliding_window_attention(Q, K, V, w)
    buckets = build_length_buckets(lengths)
    if buckets:
        log("fa2.win.buckets", n=len(buckets), max_len=max_len)
        # Build cu_seqlens per bucket (for the FA-2 varlen call)
        for idx in buckets:
            blens = lengths[idx]
            _ = build_cu_seqlens_for_buckets(blens)
    # Small-length auto-switch to masked SDPA
    if max_len < min_len_for_fa2:
        if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="win",
                reason="below_min_len",
                max_len=int(max_len),
                min_len=int(min_len_for_fa2),
            )
        return sliding_window_attention(Q, K, V, w)
    # Capability check
    ok, why = fa2_supported_verbose(device, Q.dtype, Dk)
    if not ok or not is_flash_varlen_available():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win", reason=why, has_varlen=is_flash_varlen_available())
        return sliding_window_attention(Q, K, V, w)
    # Attempt FA-2 across all rows using varlen first, then dense per bucket.
    # Fall back to masked SDPA on error.
    try:
        B, S, G, h, Dk = Q.shape
        Dv = V.shape[-1]
        use_timing = os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes")
        force_varlen = _env_bool("NSA_FA2_FORCE_VARLEN", False)
        force_dense = _env_bool("NSA_FA2_FORCE_DENSE", False)
        force_win_dense = _env_bool("NSA_WIN_FORCE_DENSE", False)
        # Log a histogram of lengths
        if buckets:
            uniq, counts = torch.unique(lengths, return_counts=True)
            log("fa2.win.hist", uniq=uniq.tolist(), counts=counts.tolist())
        # Try a single varlen call across all rows
        if (is_flash_varlen_available() and not (force_dense or force_win_dense)) or force_varlen:
            rows = []
            len_rows = []
            for t in range(S):
                L = int(lengths[t].item())
                for b in range(B):
                    for g in range(G):
                        rows.append((b, t, g))
                        len_rows.append(L)
            N = len(rows)
            if N > 0 and max_len >= 1:
                use_safe_pack = (
                    torch.is_grad_enabled()
                    and (Q.requires_grad or K.requires_grad or V.requires_grad)
                ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)
                if use_safe_pack:
                    # Autograd-safe packing via stack/cat to preserve graph links
                    q_pack = torch.stack([Q[b, t, g] for (b, t, g) in rows], dim=0)  # [N,h,Dk]
                    k_rows = []
                    v_rows = []
                    for i, (b, t, g) in enumerate(rows):
                        L = len_rows[i]
                        if L > 0:
                            start = max(0, (t + 1) - w)
                            end = t + 1
                            seg_k = K[b, g, start:end].unsqueeze(1).expand(-1, h, -1)  # [L,h,Dk]
                            seg_v = V[b, g, start:end].unsqueeze(1).expand(-1, h, -1)  # [L,h,Dv]
                            k_rows.append(seg_k)
                            v_rows.append(seg_v)
                    total_k = int(sum(len_rows))
                    if total_k > 0:
                        k_pack = torch.cat(k_rows, dim=0)
                        v_pack = torch.cat(v_rows, dim=0)
                    else:
                        k_pack = torch.zeros((0, h, Dk), dtype=K.dtype, device=K.device)
                        v_pack = torch.zeros((0, h, Dv), dtype=V.dtype, device=V.device)
                    cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
                    lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
                    cuk = torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0)
                else:
                    total_k = int(sum(len_rows))
                    ws = _get_varlen_workspace(
                        Q.device, Q.dtype, K.dtype, V.dtype, h, Dk, Dv, N, total_k
                    )
                    q_pack = ws["q"][:N]
                    k_pack = ws["k"][:total_k]
                    v_pack = ws["v"][:total_k]
                    # Build cumulative sequence lengths for Q and K
                    cuq = ws["cuq"][: N + 1]
                    cuq.copy_(torch.arange(0, N + 1, device=Q.device, dtype=torch.int32))
                    lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
                    cuk = ws["cuk"][: N + 1]
                    torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0, out=cuk)
                    # Fill packs
                    write_pos = 0
                    for i, (b, t, g) in enumerate(rows):
                        L = len_rows[i]
                        q_pack[i] = Q[b, t, g]
                        if L > 0:
                            start = max(0, (t + 1) - w)
                            end = t + 1
                            seg_k = K[b, g, start:end]  # [L,Dk]
                            seg_v = V[b, g, start:end]  # [L,Dv]
                            assert (write_pos + L) <= total_k, "varlen K/V pack overflow"
                            k_pack[write_pos : write_pos + L] = seg_k.unsqueeze(1).expand(L, h, Dk)
                            v_pack[write_pos : write_pos + L] = seg_v.unsqueeze(1).expand(L, h, Dv)
                            write_pos += L
                # Optional integrity checks (debug only)
                if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
                    try:
                        assert cuq.numel() == (N + 1), "cuq length mismatch"
                        assert cuk.numel() == (N + 1), "cuk length mismatch"
                        assert int(cuk[-1].item()) == int(total_k), "cuk total_k mismatch"
                        if total_k > 0 and N > 0:
                            probe = [0, N // 2, N - 1] if N >= 3 else [0]
                            for i in probe:
                                L_i = int(len_rows[i])
                                b_i, t_i, g_i = rows[i]
                                s_i = int(max(0, (t_i + 1) - w))
                                e_i = int(t_i + 1)
                                if L_i > 0:
                                    ks = k_pack[cuk[i] : cuk[i + 1]]  # [L,h,Dk]
                                    kv = K[b_i, g_i, s_i:e_i].unsqueeze(1).expand(-1, h, -1)
                                    if ks.shape != kv.shape:
                                        log(
                                            "warn.fa2_win_pack_shape",
                                            row=i,
                                            ks=ks.shape,
                                            kv=kv.shape,
                                        )
                                    else:
                                        md = float((ks - kv).abs().max().item())
                                        if md > 1e-3:
                                            log(
                                                "warn.fa2_win_pack_mismatch",
                                                row=i,
                                                L=L_i,
                                                max_diff=md,
                                            )
                    except Exception:
                        pass

                if use_timing:
                    t0 = time.perf_counter()
                o_pack = attention_fa2_varlen(
                    q_pack,
                    k_pack,
                    v_pack,
                    cuq,
                    cuk,
                    max_seqlen_q=1,
                    max_seqlen_k=max_len,
                    causal=False,
                )  # [N,h,Dv]
                if not torch.isfinite(o_pack).all():
                    log("warn.fa2_win_varlen_nonfinite")
                    return sliding_window_attention(Q, K, V, w)
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.win.varlen_all", N=int(N), total_k=int(total_k), ms=dt)
                # Scatter back
                out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
                for i, (b, t, g) in enumerate(rows):
                    out[b, t, g] = o_pack[i]
                return out
        out = torch.zeros((B, S, G, h, Dv), dtype=V.dtype, device=V.device)
        for idx in buckets:
            if idx.numel() == 0:
                continue
            L = int(lengths[idx[0]].item())
            # Collect rows for this bucket
            rows_q = []  # [N,h,Dk]
            rows_k = []  # [N,L,Dk]
            rows_v = []  # [N,L,Dv]
            tgt = []
            for t in idx.tolist():
                start = max(0, (t + 1) - w)
                end = t + 1
                for b in range(B):
                    for g in range(G):
                        rows_q.append(Q[b, t, g])
                        rows_k.append(K[b, g, start:end])
                        rows_v.append(V[b, g, start:end])
                        tgt.append((b, t, g))
            if not rows_q:
                continue
            N = len(rows_q)
            Qb = torch.stack(rows_q, dim=0)  # [N,h,Dk]
            Kb = torch.stack(rows_k, dim=0)  # [N,L,Dk]
            Vb = torch.stack(rows_v, dim=0)  # [N,L,Dv]
            if is_flash_varlen_available() and not (force_dense or force_win_dense):
                # Pack varlen (L is constant here, but use the API for generality)
                q_pack = Qb  # [N,h,Dk]
                k_pack = Kb.reshape(N * L, Dk).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dk)
                v_pack = Vb.reshape(N * L, Dv).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dv)
                cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
                cuk = torch.arange(0, (N + 1) * L, step=L, device=Q.device, dtype=torch.int32)
                if use_timing:
                    t0 = time.perf_counter()
                o_pack = attention_fa2_varlen(
                    q_pack,
                    k_pack,
                    v_pack,
                    cuq,
                    cuk,
                    max_seqlen_q=1,
                    max_seqlen_k=L,
                    causal=False,
                )  # [N,h,Dv]
                if not torch.isfinite(o_pack).all():
                    log("warn.fa2_win_bucket_nonfinite")
                    return sliding_window_attention(Q, K, V, w)
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.win.bucket", path="varlen", L=L, N=int(N), ms=dt)
                Ob = o_pack  # [N,h,Dv]
            else:
                q_rows = Qb.unsqueeze(1)  # [N,1,h,Dk]
                k_rows = Kb.unsqueeze(2).expand(N, L, h, Dk)
                v_rows = Vb.unsqueeze(2).expand(N, L, h, Dv)
                if use_timing:
                    t0 = time.perf_counter()
                Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False).squeeze(1)  # [N,h,Dv]
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.win.bucket", path="dense", L=L, N=int(N), ms=dt)
            for i, (b, t, g) in enumerate(tgt):
                out[b, t, g] = Ob[i]
        return out
    except Exception as e:
        log("warn.fa2_unexpected_fallback", branch="win", error=str(e)[:100])
        return sliding_window_attention_masked(Q, K, V, w)


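# Varlen layout sketch for the packing above (hypothetical lengths): three
# rows with key lengths [3, 1, 2] pack into
#   cuq = [0, 1, 2, 3]   # one query token per row
#   cuk = [0, 3, 4, 6]   # running key offsets
# so row i attends k_pack[cuk[i]:cuk[i+1]] and its result lands in o_pack[i].
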
def compressed_attention_fa2(
    Q: torch.Tensor,  # [B,S,G,h,Dk]
    K_cmp: torch.Tensor,  # [B,G,S_cmp,Dk]
    V_cmp: torch.Tensor,  # [B,G,S_cmp,Dv]
    l: int,
    d: int,
    min_len_for_fa2: int = 16,
) -> torch.Tensor:
    """
    FA-2 path for the compressed branch with safe fallbacks to masked SDPA
    whenever FA-2 is unavailable, unprofitable, or misbehaves.
    """
    B, S, G, h, Dk = Q.shape
    device = Q.device
    # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
    if _is_sm89(device) and not _fa2_forced():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="cmp", reason="sm89_guard", forced=bool(_fa2_forced()))
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
    S_cmp = K_cmp.shape[2]
    if S_cmp == 0:
        return torch.zeros((B, S, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
    num_cmp = compute_compressed_lengths(S, l, d, S_cmp, device)
    max_len = int(num_cmp.max().item()) if num_cmp.numel() > 0 else 0
    try:
        min_len_for_fa2 = int(os.getenv("NSA_FA2_MIN_LEN_CMP", str(min_len_for_fa2)))
    except Exception:
        pass
    # Disable sentinel: a non-positive threshold disables FA-2 entirely for this branch
    if min_len_for_fa2 <= 0:
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="cmp", reason="disabled_threshold")
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
    buckets = build_length_buckets(num_cmp)
    if buckets:
        log("fa2.cmp.buckets", n=len(buckets), max_len=max_len)
        for idx in buckets:
            blens = num_cmp[idx]
            _ = build_cu_seqlens_for_buckets(blens)
    if max_len < min_len_for_fa2:
        if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="cmp",
                reason="below_min_len",
                max_len=int(max_len),
                min_len=int(min_len_for_fa2),
            )
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
    ok, why = fa2_supported_verbose(device, Q.dtype, Dk)
    if not ok or not is_flash_varlen_available():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="cmp", reason=why, has_varlen=is_flash_varlen_available())
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
    try:
        Dv = V_cmp.shape[-1]
        use_timing = os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes")
        # Log a histogram of lengths
        if buckets:
            uniq, counts = torch.unique(num_cmp, return_counts=True)
            log("fa2.cmp.hist", uniq=uniq.tolist(), counts=counts.tolist())
        # Try a single varlen call across all rows with L > 0
        force_varlen = _env_bool("NSA_FA2_FORCE_VARLEN", False)
        force_dense = _env_bool("NSA_FA2_FORCE_DENSE", False)
        if ((is_flash_varlen_available() and not force_dense) or force_varlen) and max_len >= 1:
            rows = []
            len_rows = []
            for t in range(S):
                L = int(num_cmp[t].item())
                for b in range(B):
                    for g in range(G):
                        if L > 0:
                            rows.append((b, t, g))
                            len_rows.append(L)
            N = len(rows)
            if N > 0:
                total_k = int(sum(len_rows))
                use_safe_pack = (
                    torch.is_grad_enabled()
                    and (Q.requires_grad or K_cmp.requires_grad or V_cmp.requires_grad)
                ) or _env_bool("NSA_TRAIN_SAFE_PACK", False)
                if use_safe_pack:
                    q_pack = torch.stack([Q[b, t, g] for (b, t, g) in rows], dim=0)
                    k_rows = []
                    v_rows = []
                    for (b, t, g), L in zip(rows, len_rows):
                        if L > 0:
                            seg_k = K_cmp[b, g, :L]
                            seg_v = V_cmp[b, g, :L]
                            k_rows.append(seg_k.unsqueeze(1).expand(-1, h, -1))  # [L,h,Dk]
                            v_rows.append(seg_v.unsqueeze(1).expand(-1, h, -1))  # [L,h,Dv]
                    if total_k > 0:
                        k_pack = torch.cat(k_rows, dim=0)
                        v_pack = torch.cat(v_rows, dim=0)
                    else:
                        k_pack = torch.zeros((0, h, Dk), dtype=K_cmp.dtype, device=K_cmp.device)
                        v_pack = torch.zeros((0, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
                    cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
                    lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
                    cuk = torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0)
                else:
                    ws = _get_varlen_workspace(
                        Q.device, Q.dtype, K_cmp.dtype, V_cmp.dtype, h, Dk, Dv, N, total_k
                    )
                    q_pack = ws["q"][:N]
                    k_pack = ws["k"][:total_k]
                    v_pack = ws["v"][:total_k]
                    cuq = ws["cuq"][: N + 1]
                    cuq.copy_(torch.arange(0, N + 1, device=Q.device, dtype=torch.int32))
                    lens_t = torch.tensor(len_rows, dtype=torch.int32, device=Q.device)
                    cuk = ws["cuk"][: N + 1]
                    torch.cumsum(torch.nn.functional.pad(lens_t, (1, 0)), dim=0, out=cuk)
                    write_pos = 0
                    for i, (b, t, g) in enumerate(rows):
                        L = len_rows[i]
                        q_pack[i] = Q[b, t, g]
                        if L > 0:
                            seg_k = K_cmp[b, g, :L]
                            seg_v = V_cmp[b, g, :L]
                            assert (write_pos + L) <= total_k, "varlen cmp K/V pack overflow"
                            k_pack[write_pos : write_pos + L] = seg_k.unsqueeze(1).expand(L, h, Dk)
                            v_pack[write_pos : write_pos + L] = seg_v.unsqueeze(1).expand(L, h, Dv)
                            write_pos += L
                if use_timing:
                    t0 = time.perf_counter()
                o_pack = attention_fa2_varlen(
                    q_pack,
                    k_pack,
                    v_pack,
                    cuq,
                    cuk,
                    max_seqlen_q=1,
                    max_seqlen_k=max_len,
                    causal=False,
                )  # [N,h,Dv]
                if not torch.isfinite(o_pack).all():
                    log("warn.fa2_cmp_varlen_nonfinite")
                    return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.cmp.varlen_all", N=int(N), total_k=int(total_k), ms=dt)
                out = torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
                for i, (b, t, g) in enumerate(rows):
                    out[b, t, g] = o_pack[i]
                return out
        out = torch.zeros((B, S, G, h, Dv), dtype=V_cmp.dtype, device=V_cmp.device)
        for idx in buckets:
            if idx.numel() == 0:
                continue
            L = int(num_cmp[idx[0]].item())
            rows_q = []  # [N,h,Dk]
            rows_k = []  # [N,L,Dk]
            rows_v = []  # [N,L,Dv]
            tgt = []
            for t in idx.tolist():
                if L <= 0:
                    continue
                for b in range(B):
                    for g in range(G):
                        rows_q.append(Q[b, t, g])
                        rows_k.append(K_cmp[b, g, :L])
                        rows_v.append(V_cmp[b, g, :L])
                        tgt.append((b, t, g))
            if not rows_q:
                continue
            N = len(rows_q)
            Qb = torch.stack(rows_q, dim=0)  # [N,h,Dk]
            Kb = torch.stack(rows_k, dim=0)  # [N,L,Dk]
            Vb = torch.stack(rows_v, dim=0)  # [N,L,Dv]
            if is_flash_varlen_available() and not force_dense:
                q_pack = Qb
                k_pack = Kb.reshape(N * L, Dk).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dk)
                v_pack = Vb.reshape(N * L, Dv).unsqueeze(1).expand(-1, h, -1).reshape(N * L, h, Dv)
                cuq = torch.arange(0, N + 1, device=Q.device, dtype=torch.int32)
                cuk = torch.arange(0, (N + 1) * L, step=L, device=Q.device, dtype=torch.int32)
                if use_timing:
                    t0 = time.perf_counter()
                o_pack = attention_fa2_varlen(
                    q_pack,
                    k_pack,
                    v_pack,
                    cuq,
                    cuk,
                    max_seqlen_q=1,
                    max_seqlen_k=L,
                    causal=False,
                )  # [N,h,Dv]
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.cmp.bucket", path="varlen", L=L, N=int(N), ms=dt)
                Ob = o_pack
            else:
                q_rows = Qb.unsqueeze(1)
                k_rows = Kb.unsqueeze(2).expand(N, L, h, Dk)
                v_rows = Vb.unsqueeze(2).expand(N, L, h, Dv)
                if use_timing:
                    t0 = time.perf_counter()
                Ob = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=True).squeeze(1)
                if not torch.isfinite(Ob).all():
                    return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)
                if use_timing:
                    dt = (time.perf_counter() - t0) * 1e3
                    log("fa2.cmp.bucket", path="dense", L=L, N=int(N), ms=dt)
            for i, (b, t, g) in enumerate(tgt):
                out[b, t, g] = Ob[i]
        return out
    except Exception as e:
        log("warn.fa2_unexpected_fallback", branch="cmp", error=str(e)[:100])
        return batched_causal_attention_compressed_masked(Q, K_cmp, V_cmp, l, d)


def sliding_window_attention_fa2_decode(
    q_t: torch.Tensor, K_win: torch.Tensor, V_win: torch.Tensor, w: int
) -> torch.Tensor:
    B, G, h, Dk = q_t.shape
    # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
    if _is_sm89(q_t.device) and not _fa2_forced():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="win.decode",
                reason="sm89_guard",
                forced=bool(_fa2_forced()),
            )
        end = K_win.shape[2]
        win_len = min(w, end)
        if win_len == 0:
            return torch.zeros((B, G, h, V_win.shape[-1]), dtype=V_win.dtype, device=V_win.device)
        start = end - win_len
        return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True)
    end = K_win.shape[2]
    win_len = min(w, end)
    if win_len == 0:
        return torch.zeros((B, G, h, V_win.shape[-1]), dtype=V_win.dtype, device=V_win.device)
    # CPU or unsupported: direct SDPA for parity
    ok, why = fa2_supported_verbose(q_t.device, q_t.dtype, Dk)
    if not ok:
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="win.decode", reason=why)
        start = end - win_len
        return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True)
    # Small-length auto-switch for decode
    try:
        min_len = int(os.getenv("NSA_FA2_MIN_LEN_WIN", "16"))
    except Exception:
        min_len = 16
    if min_len < 1:
        min_len = 1
    if win_len < min_len:
        if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="win.decode",
                reason="below_min_len",
                win_len=int(win_len),
                min_len=int(min_len),
            )
        start = end - win_len
        return attention_bgh(q_t, K_win[:, :, start:end], V_win[:, :, start:end], causal=True)
    start = end - win_len
    k = K_win[:, :, start:end]
    v = V_win[:, :, start:end]
    N = B * G
    q_rows = q_t.reshape(N, h, Dk).unsqueeze(1)  # [N,1,h,Dk]
    k_rows = k.reshape(N, win_len, Dk).unsqueeze(2).expand(N, win_len, h, Dk)
    v_rows = v.reshape(N, win_len, v.shape[-1]).unsqueeze(2).expand(N, win_len, h, v.shape[-1])
    try:
        o = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False)  # [N,1,h,Dv]
        o = o.squeeze(1).reshape(B, G, h, -1)
        if not torch.isfinite(o).all():
            return attention_bgh(q_t, k, v, causal=True)
        return o
    except Exception as e:
        log("warn.fa2_unexpected_fallback", branch="win.decode", error=str(e)[:100])
        return attention_bgh(q_t, k, v, causal=True)


| 1352 |
+
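The decode path above reduces to plain softmax attention over the last `min(w, S)` cached keys: the single decode query is the newest position, so a causal mask over the window is a no-op — which is why the dense FA-2 call can pass `causal=False`. A minimal SDPA-style reference, as a sketch; `attention_bgh` is assumed to be numerically equivalent to this for one query:

```python
import torch
import torch.nn.functional as F

def sliding_window_decode_ref(q_t, K_win, V_win, w):
    # q_t: [B,G,h,Dk]; K_win/V_win: [B,G,S,D*]. The decode query sits at the
    # last position, so attending all of the last min(w, S) keys is exact.
    B, G, h, Dk = q_t.shape
    S = K_win.shape[2]
    win_len = min(w, S)
    if win_len == 0:
        return torch.zeros((B, G, h, V_win.shape[-1]), dtype=V_win.dtype, device=V_win.device)
    k = K_win[:, :, S - win_len :]  # [B,G,win_len,Dk]
    v = V_win[:, :, S - win_len :]  # [B,G,win_len,Dv]
    attn = torch.einsum("bghd,bgsd->bghs", q_t, k) / (Dk ** 0.5)
    p = F.softmax(attn, dim=-1)
    return torch.einsum("bghs,bgsv->bghv", p, v)  # [B,G,h,Dv]
```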
```python
def compressed_attention_fa2_decode(
    q_t: torch.Tensor, K_cmp: torch.Tensor, V_cmp: torch.Tensor, L: int
) -> torch.Tensor:
    if L <= 0:
        B, G, h, _ = q_t.shape
        return torch.zeros((B, G, h, V_cmp.shape[-1]), dtype=V_cmp.dtype, device=V_cmp.device)
    B, G, h, Dk = q_t.shape
    # Guard: disable FA-2 on Ada (SM 8.9) unless explicitly forced
    if _is_sm89(q_t.device) and not _fa2_forced():
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="cmp.decode",
                reason="sm89_guard",
                forced=bool(_fa2_forced()),
            )
        return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
    ok, why = fa2_supported_verbose(q_t.device, q_t.dtype, Dk)
    if not ok:
        if os.getenv("NSA_SDPA_AUDIT", "0").lower() in ("1", "true", "yes"):
            log("fa2.gate_skip", branch="cmp.decode", reason=why)
        return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
    try:
        min_len = int(os.getenv("NSA_FA2_MIN_LEN_CMP", "16"))
    except Exception:
        min_len = 16
    if min_len < 1:
        min_len = 1
    if L < min_len:
        if os.getenv("NSA_DEBUG_TIMING", "0").lower() in ("1", "true", "yes"):
            log(
                "fa2.gate_skip",
                branch="cmp.decode",
                reason="below_min_len",
                L=int(L),
                min_len=int(min_len),
            )
        return attention_bgh(q_t, K_cmp[:, :, :L], V_cmp[:, :, :L], causal=True)
    k = K_cmp[:, :, :L]
    v = V_cmp[:, :, :L]
    N = B * G
    q_rows = q_t.reshape(N, h, Dk).unsqueeze(1)
    k_rows = k.reshape(N, L, Dk).unsqueeze(2).expand(N, L, h, Dk)
    v_rows = v.reshape(N, L, v.shape[-1]).unsqueeze(2).expand(N, L, h, v.shape[-1])
    try:
        o = attention_fa2_dense_batch(q_rows, k_rows, v_rows, causal=False)
        o = o.squeeze(1).reshape(B, G, h, -1)
        if not torch.isfinite(o).all():
            return attention_bgh(q_t, k, v, causal=True)
        return o
    except Exception:
        return attention_bgh(q_t, k, v, causal=True)
```
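`compressed_attention_fa2_decode` attends the decode query over the first `L` compressed tokens, where `L` grows by one each time a new overlapped block is emitted. A small sketch of that emission count (the same formula appears as `num_cmp` in `nsa_attention.py` below and as `max_cmp` in `block_index.py`):

```python
# Sketch: how many compressed (cmp) tokens exist after S raw tokens, with
# overlapping blocks of length l emitted every d steps (defaults l=32, d=16).
def num_cmp_tokens(S: int, l: int = 32, d: int = 16) -> int:
    return 0 if S < l else (S - l) // d + 1

assert num_cmp_tokens(31) == 0   # warmup: nothing until l raw tokens exist
assert num_cmp_tokens(32) == 1   # first block covers tokens [0, 32)
assert num_cmp_tokens(48) == 2   # next emission at 32 + d = 48
assert num_cmp_tokens(128) == 7  # (128 - 32) // 16 + 1
```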
nsa/core/block_index.py
ADDED
@@ -0,0 +1,100 @@
```python
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass
class BlockMeta:
    l: int
    d: int
    l_sel: int
    n_sel: int
    w: int
    cmp_starts: torch.Tensor  # [S_cmp]
    sel_starts: torch.Tensor  # [S_sel]
    # CSR representation: (indptr, indices, values) mapping cmp_idx -> {sel_idx: weight}
    M_csl_indptr: torch.Tensor
    M_csl_indices: torch.Tensor
    M_csl_values: torch.Tensor
    # COO representation for fast batched matmul
    M_csl_coo_indices: torch.Tensor  # [2, nnz] rows, cols
    M_csl_coo_values: torch.Tensor  # [nnz]


def build_block_starts(
    seq_len: int, l: int, d: int, l_sel: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    if d <= 0 or l <= 0 or l_sel <= 0:
        raise ValueError("Block parameters must be positive")
    # compression blocks (overlapped)
    max_cmp = 0 if seq_len < l else (seq_len - l) // d + 1
    cmp_starts = torch.arange(max_cmp, dtype=torch.int32) * d
    # selection blocks (non-overlapped)
    max_sel = 0 if seq_len <= 0 else (seq_len + l_sel - 1) // l_sel
    sel_starts = torch.arange(max_sel, dtype=torch.int32) * l_sel
    return cmp_starts, sel_starts


def _overlap_len(a0: int, a1: int, b0: int, b1: int) -> int:
    return max(0, min(a1, b1) - max(a0, b0))


def build_M_csl_csr(
    seq_len: int, l: int, d: int, l_sel: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Build CSR with fractional-overlap weights from cmp blocks to sel blocks
    cmp_starts, sel_starts = build_block_starts(seq_len, l, d, l_sel)
    indptr = [0]
    indices: List[int] = []
    values: List[float] = []
    for cmp_i, s in enumerate(cmp_starts.tolist()):
        a0, a1 = s, s + l
        total = 0
        row_pairs: List[Tuple[int, int]] = []
        for sel_j, t in enumerate(sel_starts.tolist()):
            b0, b1 = t, t + l_sel
            ov = _overlap_len(a0, a1, b0, b1)
            if ov > 0:
                row_pairs.append((sel_j, ov))
                total += ov
        # normalize by total overlap to get fractional weights
        if total > 0:
            for sel_j, ov in row_pairs:
                indices.append(sel_j)
                values.append(ov / total)
        indptr.append(len(indices))
    return (
        torch.tensor(indptr, dtype=torch.int32),
        torch.tensor(indices, dtype=torch.int32),
        torch.tensor(values, dtype=torch.float32),
    )


def build_block_meta(seq_len: int, l: int, d: int, l_sel: int, n_sel: int, w: int) -> BlockMeta:
    if l % d != 0 or l_sel % d != 0:
        # Enforce divisibility by default (per PRD); general overlaps allowed later if needed
        raise ValueError("Require d|l and d|l_sel in M0")
    cmp_starts, sel_starts = build_block_starts(seq_len, l, d, l_sel)
    indptr, indices, values = build_M_csl_csr(seq_len, l, d, l_sel)
    # Build COO from CSR
    rows: List[int] = []
    for r in range(len(cmp_starts)):
        start, end = int(indptr[r].item()), int(indptr[r + 1].item())
        rows.extend([r] * (end - start))
    coo_indices = torch.stack([torch.tensor(rows, dtype=torch.int32), indices.clone()], dim=0)
    return BlockMeta(
        l=l,
        d=d,
        l_sel=l_sel,
        n_sel=n_sel,
        w=w,
        cmp_starts=cmp_starts,
        sel_starts=sel_starts,
        M_csl_indptr=indptr,
        M_csl_indices=indices,
        M_csl_values=values,
        M_csl_coo_indices=coo_indices,
        M_csl_coo_values=values.clone(),
    )
```
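To illustrate the metadata this file produces, a short usage sketch (expected values computed by hand from the overlap rule above):

```python
import torch
from nsa.core.block_index import build_block_meta

meta = build_block_meta(seq_len=128, l=32, d=16, l_sel=64, n_sel=16, w=512)
print(meta.cmp_starts)  # tensor([ 0, 16, 32, 48, 64, 80, 96], dtype=torch.int32)
print(meta.sel_starts)  # tensor([ 0, 64], dtype=torch.int32)
# The cmp block starting at 48 spans [48, 80): 16 tokens overlap sel block 0
# and 16 tokens overlap sel block 1, so its row weights are (0.5, 0.5).
row = 3  # cmp_starts[3] == 48
s, e = int(meta.M_csl_indptr[row]), int(meta.M_csl_indptr[row + 1])
print(meta.M_csl_indices[s:e], meta.M_csl_values[s:e])  # [0, 1] and [0.5, 0.5]
```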
nsa/core/collate.py
ADDED
@@ -0,0 +1,45 @@
```python
from __future__ import annotations
from typing import List, Tuple

import torch


def collate_token_batch(
    sequences: List[List[int]],
    *,
    pad_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Collate variable-length token-id sequences into padded tensors and masks with a label shift.

    Args:
        sequences: list of token id lists
        pad_id: id used for padding
    Returns:
        input_ids: [B,S_max]
        labels: [B,S_max] (next-token labels; last position masked out)
        attn_mask: [B,S_max] (True for valid tokens)
        loss_mask: [B,S_max] (True for positions to include in loss)
        lengths: [B]
        cu_seqlens: [B+1] cumulative lengths
    """
    B = len(sequences)
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.int32)
    S_max = int(lengths.max().item()) if B > 0 else 0
    input_ids = torch.full((B, S_max), pad_id, dtype=torch.long)
    labels = torch.full((B, S_max), pad_id, dtype=torch.long)
    attn_mask = torch.zeros((B, S_max), dtype=torch.bool)
    loss_mask = torch.zeros((B, S_max), dtype=torch.bool)
    for b, seq in enumerate(sequences):
        L = len(seq)
        if L == 0:
            continue
        input_ids[b, :L] = torch.tensor(seq, dtype=torch.long)
        attn_mask[b, :L] = True
        # next-token labels (shifted left by 1); the last token has no next label
        labels[b, : L - 1] = input_ids[b, 1:L]
        loss_mask[b, : L - 1] = True
    # cu_seqlens for varlen APIs
    cu = torch.zeros((B + 1,), dtype=torch.int32)
    cu[1:] = torch.cumsum(lengths, dim=0)
    return input_ids, labels, attn_mask, loss_mask, lengths, cu
```
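A short usage sketch of the collator; the expected values follow from the shift rule above:

```python
from nsa.core.collate import collate_token_batch

input_ids, labels, attn_mask, loss_mask, lengths, cu = collate_token_batch(
    [[5, 6, 7], [8, 9]], pad_id=0
)
# input_ids: [[5, 6, 7], [8, 9, 0]]
# labels:    [[6, 7, 0], [9, 0, 0]]   (next-token targets; last position padded)
# loss_mask: [[T, T, F], [T, F, F]]   (only positions that have a real next token)
# cu:        [0, 3, 5]                (cumulative lengths for varlen kernels)
```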
nsa/core/compress_pool.py
ADDED
@@ -0,0 +1,39 @@
```python
from __future__ import annotations
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from .rope import apply_rope


def avg_pool_phi_rope_kv(
    K_raw: torch.Tensor,
    V_raw: torch.Tensor,
    l: int,
    d: int,
    pos: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Apply RoPE to K before ϕ; use absolute positions if provided
    S = K_raw.shape[2]
    if pos is None:
        pos = torch.arange(S, device=K_raw.device)
    K_rope = apply_rope(K_raw, pos)
    V_rope = V_raw
    # Expect shapes [B,G,S,D*]
    B, G, S, Dk = K_rope.shape
    # If sequence shorter than kernel, no compressed tokens yet
    if S < l:
        return (
            torch.zeros((B, G, 0, Dk), device=K_rope.device, dtype=K_rope.dtype),
            torch.zeros((B, G, 0, V_rope.shape[-1]), device=V_rope.device, dtype=V_rope.dtype),
        )
    # Unfold over time with stride d and kernel l (causal pooling over past)
    Kf = K_rope.reshape(B * G, S, Dk).transpose(1, 2).unsqueeze(3)  # [B*G, Dk, S, 1]
    Vf = V_rope.reshape(B * G, S, -1).transpose(1, 2).unsqueeze(3)
    Kp = F.avg_pool2d(Kf[:, :, :S, :], kernel_size=(l, 1), stride=(d, 1))  # [B*G, Dk, S_cmp, 1]
    Vp = F.avg_pool2d(Vf[:, :, :S, :], kernel_size=(l, 1), stride=(d, 1))
    S_cmp = Kp.shape[2]
    K_cmp = Kp.squeeze(3).transpose(1, 2).reshape(B, G, S_cmp, Dk)
    V_cmp = Vp.squeeze(3).transpose(1, 2).reshape(B, G, S_cmp, V_rope.shape[-1])
    return K_cmp, V_cmp
```
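A shape-check sketch: the pooled length matches the block-emission formula used elsewhere in the repo, `S_cmp = (S - l) // d + 1` for `S >= l`:

```python
import torch
from nsa.core.compress_pool import avg_pool_phi_rope_kv

B, G, S, Dk, Dv = 1, 2, 128, 32, 32
K = torch.randn(B, G, S, Dk)
V = torch.randn(B, G, S, Dv)
K_cmp, V_cmp = avg_pool_phi_rope_kv(K, V, l=32, d=16)
assert K_cmp.shape == (B, G, 7, Dk)  # (128 - 32) // 16 + 1 = 7
assert V_cmp.shape == (B, G, 7, Dv)
```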
nsa/core/debug.py
ADDED
@@ -0,0 +1,44 @@
```python
from __future__ import annotations
import os
from typing import Any, Dict


def _flag(name: str) -> bool:
    val = os.getenv(name, "0").lower()
    return val in ("1", "true", "yes")


def debug_enabled() -> bool:
    return _flag("NSA_DEBUG_LOG")


_COUNTS: Dict[str, int] = {}


def log(tag: str, **fields: Any) -> None:
    if not debug_enabled():
        return
    limit_env = os.getenv("NSA_LOG_LIMIT")
    if limit_env is not None:
        try:
            limit = int(limit_env)
        except Exception:
            limit = 0
        if limit > 0:
            cnt = _COUNTS.get(tag, 0)
            if cnt >= limit:
                return
            _COUNTS[tag] = cnt + 1
    parts = [f"{k}={_safe(v)}" for k, v in fields.items()]
    print(f"NSA-LOG {tag} " + " ".join(parts))


def _safe(v: Any) -> str:
    try:
        if isinstance(v, int | float | str):
            return str(v)
        if hasattr(v, "shape"):
            return str(tuple(int(x) for x in v.shape))
        return str(v)
    except Exception:
        return "<unrepr>"
```
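Usage sketch: logging is a no-op unless `NSA_DEBUG_LOG` is set, and `NSA_LOG_LIMIT` caps emitted lines per tag (tensors are rendered by their shape):

```python
import os
os.environ["NSA_DEBUG_LOG"] = "1"
os.environ["NSA_LOG_LIMIT"] = "2"

import torch
from nsa.core.debug import log

for step in range(5):
    # Only the first two lines for this tag are printed.
    log("demo.step", step=step, q=torch.zeros(2, 4))
# NSA-LOG demo.step step=0 q=(2, 4)
# NSA-LOG demo.step step=1 q=(2, 4)
```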
nsa/core/flags.py
ADDED
@@ -0,0 +1,80 @@
```python
from __future__ import annotations
import os
from typing import Optional

import torch


def env_true(name: str, default: bool = False) -> bool:
    v = os.getenv(name)
    if v is None:
        return default
    v = v.strip().lower()
    return v in ("1", "true", "yes", "on")


def env_int(name: str, default: int) -> int:
    try:
        return int(os.getenv(name, str(default)))
    except Exception:
        return default


def is_sm89(device: Optional[torch.device] = None) -> bool:
    dev = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    if dev.type != "cuda":
        return False
    try:
        cap = torch.cuda.get_device_capability(dev)
        return cap == (8, 9)
    except Exception:
        return False


def torch_triton_version_pairing_ok() -> bool:
    try:
        import triton  # noqa: F401

        tv = triton.__version__
    except ImportError:
        tv = "<none>"
    except Exception:
        tv = "<unknown>"
    try:
        tt = torch.__version__
    except Exception:
        tt = "<unknown>"
    # Basic heuristic: 2.2.x ↔ triton 2.2.x; 2.3.x ↔ 2.3.x; 2.4+ ↔ 3.x
    try:
        major_minor = ".".join((tt or "").split("+")[0].split(".")[:2])
        parts = major_minor.split(".")
        t_major = int(parts[0])
        t_minor = int(parts[1])
        if t_major != 2:
            return True  # do not gate non-2.x
        if t_minor in (2, 3):
            return tv.startswith(f"{t_minor}.")
        if t_minor >= 4:
            return tv.startswith("3.")
        return True
    except (ValueError, IndexError):
        return True


def execution_routing_summary() -> dict:
    """Return a snapshot of routing-related flags and runtime probes."""
    info = {
        "cuda": torch.cuda.is_available(),
        "sm89": is_sm89(),
        "torch": torch.__version__,
    }
    try:
        import triton

        info["triton"] = triton.__version__
    except Exception:
        info["triton"] = "<none>"
    info["NSA_USE_TRITON_SEL"] = env_true("NSA_USE_TRITON_SEL", False)
    info["NSA_TRITON_SEL_FORCE"] = env_true("NSA_TRITON_SEL_FORCE", False)
    info["NSA_USE_FA2"] = env_true("NSA_USE_FA2", False)
    return info
```
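Usage sketch: the summary is a plain dict, handy to record at startup; the values shown are illustrative and depend on the host:

```python
from nsa.core.flags import execution_routing_summary, torch_triton_version_pairing_ok

print(execution_routing_summary())
# e.g. {'cuda': True, 'sm89': False, 'torch': '2.3.1', 'triton': '2.3.0',
#       'NSA_USE_TRITON_SEL': False, 'NSA_TRITON_SEL_FORCE': False, 'NSA_USE_FA2': False}
print(torch_triton_version_pairing_ok())  # True for torch 2.3.x paired with triton 2.3.x
```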
nsa/core/nsa_attention.py
ADDED
@@ -0,0 +1,1850 @@
```python
from __future__ import annotations
import os
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from nsa.cache.kv_cache import NSA_KV
from nsa.core.attention_kernels import (
    compressed_attention_fa2,
    compressed_attention_fa2_decode,
    grouped_selection_attention,
    grouped_selection_attention_masked,
    grouped_selection_attention_packed,
    sliding_window_attention_fa2,
    sliding_window_attention_fa2_decode,
)
from nsa.core.block_index import build_block_meta
from nsa.core.compress_pool import avg_pool_phi_rope_kv
from nsa.core.debug import log
from nsa.core.rope import apply_rope
from nsa.core.selection_scorer import (
    compute_pcmp_all,
    map_pcmp_to_pslc_batched,
    select_topn_ranges,
    select_topn_ranges_batched,
    verify_mapping_equivalence,
)
from nsa.kernels.flash_wrappers import attention_bgh


class GateMLP(nn.Module):
    def __init__(self, d_k: int, hidden: Optional[int] = None):
        super().__init__()
        hidden = hidden or max(1, d_k // 2)
        self.fc1 = nn.Linear(d_k, hidden)
        self.fc2 = nn.Linear(hidden, 3)
        # Initialize fc2 with small random values to break symmetry and enable learning
        # Use Xavier uniform with reduced scale to start near uniform but allow differentiation
        nn.init.xavier_uniform_(self.fc2.weight, gain=0.1)
        nn.init.zeros_(self.fc2.bias)  # Keep bias at zero for initial balance
        # Cache environment variables at init to avoid hot path parsing
        self._force_uniform_gate = os.getenv("NSA_FORCE_UNIFORM_GATE", "0").lower() in (
            "1",
            "true",
            "yes",
        )
        self._force_branch = os.getenv("NSA_FORCE_BRANCH")

    def forward(self, q_group_pooled: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
        # Uniform gate override for debugging DDP hangs
        if self._force_uniform_gate:
            one_third = 1.0 / 3.0
            shape = (*q_group_pooled.shape[:-1], 3)
            return torch.full(
                shape, one_third, device=q_group_pooled.device, dtype=q_group_pooled.dtype
            )
        fb = self._force_branch
        if fb:
            fb = fb.strip().lower()
            if fb in ("cmp", "sel", "win"):
                idx = 0 if fb == "cmp" else (1 if fb == "sel" else 2)
                one = torch.zeros(
                    (*q_group_pooled.shape[:-1], 3),
                    device=q_group_pooled.device,
                    dtype=q_group_pooled.dtype,
                )
                one[..., idx] = 1.0
                return one
        x = F.silu(self.fc1(q_group_pooled))
        g = self.fc2(x) / max(tau, 1e-6)
        p = F.softmax(g, dim=-1)
        # Hard one-hot if extremely peaked to avoid numerical drift in ablations/tests
        with torch.no_grad():
            top2 = torch.topk(g, k=2, dim=-1).values
            peaked = (top2[..., 0] - top2[..., 1]) > 50.0
            if peaked.any():
                one_hot = torch.zeros_like(p)
                idx = torch.argmax(g, dim=-1, keepdim=True)
                one_hot.scatter_(-1, idx, 1.0)
                p = torch.where(peaked.unsqueeze(-1), one_hot, p)
        return p
```
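A behavior sketch for the gate: outputs are a distribution over (cmp, sel, win), and the env overrides, read at construction time, replace it with a fixed distribution for ablations:

```python
import torch
from nsa.core.nsa_attention import GateMLP

gate = GateMLP(d_k=64)
q_pooled = torch.randn(2, 4, 64)  # [..., d_k]
p = gate(q_pooled, tau=1.0)       # [..., 3] over (cmp, sel, win)
assert torch.allclose(p.sum(dim=-1), torch.ones(2, 4))
# With NSA_FORCE_BRANCH=win set before construction, p would be (0, 0, 1);
# with NSA_FORCE_UNIFORM_GATE=1 it would be (1/3, 1/3, 1/3).
```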
```python
def _fused_gate_combine_bsg(
    q_gp: torch.Tensor,  # [B,S,G,Dk]
    O_cmp: torch.Tensor,  # [B,S,G,h,Dv]
    O_sel: torch.Tensor,  # [B,S,G,h,Dv]
    O_win: torch.Tensor,  # [B,S,G,h,Dv]
    fc1_w: torch.Tensor,
    fc1_b: torch.Tensor | None,
    fc2_w: torch.Tensor,
    fc2_b: torch.Tensor | None,
    tau: float,
) -> torch.Tensor:
    import torch.nn.functional as _F

    x = _F.silu(_F.linear(q_gp, fc1_w, fc1_b))
    g = _F.linear(x, fc2_w, fc2_b) / max(tau, 1e-6)
    p = _F.softmax(g, dim=-1)
    w_cmp = p[..., 0:1].unsqueeze(-1)
    w_sel = p[..., 1:2].unsqueeze(-1)
    w_win = p[..., 2:3].unsqueeze(-1)
    return w_cmp * O_cmp + w_sel * O_sel + w_win * O_win


def _fused_gate_combine_bg(
    q_gp: torch.Tensor,  # [B,G,Dk]
    O_cmp: torch.Tensor,  # [B,G,h,Dv]
    O_sel: torch.Tensor,  # [B,G,h,Dv]
    O_win: torch.Tensor,  # [B,G,h,Dv]
    fc1_w: torch.Tensor,
    fc1_b: torch.Tensor | None,
    fc2_w: torch.Tensor,
    fc2_b: torch.Tensor | None,
    tau: float,
) -> torch.Tensor:
    import torch.nn.functional as _F

    x = _F.silu(_F.linear(q_gp, fc1_w, fc1_b))
    g = _F.linear(x, fc2_w, fc2_b) / max(tau, 1e-6)
    p = _F.softmax(g, dim=-1)
    w_cmp = p[..., 0:1].unsqueeze(-1)
    w_sel = p[..., 1:2].unsqueeze(-1)
    w_win = p[..., 2:3].unsqueeze(-1)
    return w_cmp * O_cmp + w_sel * O_sel + w_win * O_win


def _compute_gate_stats(gates: torch.Tensor) -> dict:
    """Compute gate health statistics for monitoring.

    Args:
        gates: Gate probabilities [B, S, G, 3] or [B, G, 3]

    Returns:
        Dict with gate statistics: entropy, max_gate, branch_shares
    """
    with torch.no_grad():
        # Flatten to [*, 3] for consistent computation
        gates_flat = gates.view(-1, 3)

        # Gate entropy (should be > 0.5 for healthy mixing)
        entropy = -(gates_flat * (gates_flat + 1e-8).log()).sum(dim=-1)
        mean_entropy = entropy.mean().item()
        min_entropy = entropy.min().item()

        # Max gate value (should be < 0.9 to avoid collapse)
        max_gate = gates_flat.max(dim=-1)[0]
        mean_max_gate = max_gate.mean().item()
        max_max_gate = max_gate.max().item()

        # Branch usage shares (should be balanced)
        branch_shares = gates_flat.mean(dim=0).tolist()  # [cmp, sel, win]

        # Gate collapse detection (entropy < 0.1 and max_gate > 0.95)
        collapsed = (entropy < 0.1) & (max_gate > 0.95)
        collapse_fraction = collapsed.float().mean().item()

        return {
            "entropy_mean": mean_entropy,
            "entropy_min": min_entropy,
            "max_gate_mean": mean_max_gate,
            "max_gate_max": max_max_gate,
            "branch_shares": branch_shares,  # [cmp, sel, win]
            "collapse_fraction": collapse_fraction,
            "total_gates": len(gates_flat),
        }
```
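A quick numeric sketch of the health metrics: uniform gates sit at the maximum entropy ln 3 ≈ 1.0986, and collapse is flagged only when entropy drops below 0.1 with a dominant branch above 0.95:

```python
import torch
from nsa.core.nsa_attention import _compute_gate_stats

uniform = torch.full((2, 8, 2, 3), 1.0 / 3.0)  # [B,S,G,3]
stats = _compute_gate_stats(uniform)
assert abs(stats["entropy_mean"] - 1.0986) < 1e-3  # ln(3)
assert stats["collapse_fraction"] == 0.0
print(stats["branch_shares"])  # ~[0.333, 0.333, 0.333] for cmp/sel/win
```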
```python
class NSAAttention(nn.Module):
    """
    Native Sparse Attention (NSA) module (M0 steel-thread).

    Shapes:
    - Input x (prefill): [B,S,dim]; x (decode): [B,1,dim]
    - Heads: n_heads, grouped into n_kv_groups with h_per_group = n_heads // n_kv_groups
    - Projections produce:
      - Q: [B,S,G,h,Dk]
      - K/V per-branch: [B,G,S,D*]

    Returns:
    - out: [B,S,dim] (prefill) or [B,1,dim] (decode)
    - kv: updated NSA_KV caches

    Notes:
    - M0 constraints: SDPA-only, fixed sequence length in tests, deterministic.
    - Masked/packed fast paths are env-gated with `NSA_FORCE_PARITY` fallback.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_groups: int,
        d_k: int,
        d_v: int,
        l: int = 32,
        d: int = 16,
        l_sel: int = 64,
        n_sel: int = 16,
        w: int = 512,
        phi: str = "avg",
        gate_hidden: Optional[int] = None,
        gate_temp: float = 1.0,
        rope_impl: str = "llama",
        use_flash: bool = True,
        use_triton_sel: bool = False,
    ) -> None:
        super().__init__()
        assert n_heads % n_kv_groups == 0, "heads must be divisible by kv groups"
        # M0 config validation (PRD enforces divisibility)
        if l % d != 0 or l_sel % d != 0:
            raise ValueError("M0 requires d|l and d|l_sel; set valid block sizes/stride.")
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.h_per_group = n_heads // n_kv_groups
        self.d_k = d_k
        self.d_v = d_v
        self.l = l
        self.d = d
        self.l_sel = l_sel
        self.n_sel = n_sel
        self.w = w
        self.gate_temp = gate_temp
        self.phi_type = (phi or "avg").lower()

        # Gate health tracking for M8 monitoring
        self._last_gate_stats = None
        # M8: Selection length stats for monitoring (updated each forward)
        self._last_sel_stats: Optional[dict] = None

        # M8: Fallback counters for routing monitoring
        self._fallback_counters = {
            "selection_triton_fails": 0,
            "selection_cuda_fails": 0,
            "selection_pack_fails": 0,
            "selection_mask_fails": 0,
            "compressed_fa2_fails": 0,
            "sliding_fa2_fails": 0,
            "total_fallbacks": 0,
        }

        # RoPE scaling and prefill tiling for long-context demos (env-overridable)
        try:
            rs = float(os.getenv("NSA_ROPE_SCALE", "1.0"))
            if not (rs > 0.0) or rs != rs:  # require positive finite
                rs = 1.0
            self.rope_scale = rs
        except ValueError:
            self.rope_scale = 1.0
        try:
            pt = int(os.getenv("NSA_PREFILL_TILE", "0"))
            if pt < 0:
                pt = 0
            self.prefill_tile = pt
        except ValueError:
            self.prefill_tile = 0
        # Projections
        self.W_Q = nn.Linear(dim, n_heads * d_k, bias=False)
        self.W_K_sel = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_sel = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        self.W_K_win = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_win = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        self.W_K_cmp = nn.Linear(dim, n_kv_groups * d_k, bias=False)
        self.W_V_cmp = nn.Linear(dim, n_kv_groups * d_v, bias=False)
        self.out = nn.Linear(n_heads * d_v, dim, bias=False)
        self.gate = GateMLP(d_k, gate_hidden)
        # Default FA-2 usage (can be overridden by env flags)
        self.use_flash_default = use_flash
        # One-time SDPA backend audit flag
        self._sdpa_audited = False
        # Selection Triton toggle (M4)
        self.use_triton_sel = use_triton_sel
        # Cache environment variables to avoid repeated parsing in hot path
        self._cache_env_vars()
        # Optional learnable ϕ via depthwise Conv1d over time with kernel l and stride d
        # Initialize to average pooling for parity with M0
        self.phi_k_conv: Optional[nn.Conv1d]
        self.phi_v_conv: Optional[nn.Conv1d]
        if self.phi_type == "mlp":
            self.phi_k_conv = nn.Conv1d(
                self.d_k, self.d_k, kernel_size=self.l, stride=self.d, groups=self.d_k, bias=False
            )
            self.phi_v_conv = nn.Conv1d(
                self.d_v, self.d_v, kernel_size=self.l, stride=self.d, groups=self.d_v, bias=False
            )
            with torch.no_grad():
                self.phi_k_conv.weight.fill_(1.0 / float(self.l))
                self.phi_v_conv.weight.fill_(1.0 / float(self.l))
        else:
            self.phi_k_conv = None
            self.phi_v_conv = None

    def _cache_env_vars(self) -> None:
        """Cache environment variables to avoid repeated parsing in hot path."""

        def parse_bool(val: str, default: str = "0") -> bool:
            return os.getenv(val, default).lower() in ("1", "true", "yes")

        # Cache frequently accessed environment variables
        # Raw parsed flags
        self._env_cache = {
            "static": parse_bool("NSA_ENV_STATIC", "0"),
            "force_uniform_gate": parse_bool("NSA_FORCE_UNIFORM_GATE", "0"),
            "force_branch": os.getenv("NSA_FORCE_BRANCH"),
            "prefill_batched": parse_bool("NSA_PREFILL_BATCHED", "0"),
            "strict_asserts": parse_bool("NSA_STRICT_ASSERTS", "0"),
            "force_parity": parse_bool("NSA_FORCE_PARITY", "0"),
            "use_sel_pack": parse_bool("NSA_USE_SEL_PACK", "1"),
            "use_triton_sel": parse_bool("NSA_USE_TRITON_SEL", "0") or self.use_triton_sel,
            "use_cuda_sel": parse_bool("NSA_SEL_CUDA", "0"),
            "use_sel_varlen": parse_bool("NSA_USE_SEL_VARLEN", "0"),
            # Hard override to force masked selection path (debug/triage)
            "force_sel_mask": parse_bool("NSA_FORCE_SEL_MASK", "0"),
            "fa2_all": parse_bool("NSA_USE_FA2", "0"),
            "fa2_win": parse_bool("NSA_USE_FA2_WIN", "0"),
            "fa2_cmp": parse_bool("NSA_USE_FA2_CMP", "0"),
            "use_sel_mask": parse_bool("NSA_USE_SEL_MASK", "0"),
            "use_cmp_mask": parse_bool("NSA_USE_CMP_MASK", "1"),
            "use_win_mask": parse_bool("NSA_USE_WIN_MASK", "1"),
            "verify_eq9": parse_bool("NSA_VERIFY_EQ9_MAPPING", "0"),
            "stopgrad_gates": parse_bool("NSA_STOPGRAD_GATES", "0"),
            "nvtx": parse_bool("NSA_NVTX", "0"),
            "debug_compare": parse_bool("NSA_DEBUG_COMPARE", "0"),
            "gate_compile": parse_bool("NSA_GATE_COMPILE", "0"),
        }

        # Detect whether env overrides were explicitly provided so we can honor hard-disable
        fa2_all_set = "NSA_USE_FA2" in os.environ
        fa2_win_set = "NSA_USE_FA2_WIN" in os.environ
        fa2_cmp_set = "NSA_USE_FA2_CMP" in os.environ
        self._env_cache.update(
            {
                "fa2_all_set": fa2_all_set,
                "fa2_win_set": fa2_win_set,
                "fa2_cmp_set": fa2_cmp_set,
            }
        )

        # Compute effective FA-2 gating with sensible defaults and hard-disable semantics
        fa2_all_env = self._env_cache["fa2_all"]
        fa2_win_env = self._env_cache["fa2_win"]
        fa2_cmp_env = self._env_cache["fa2_cmp"]

        # Defaults when no explicit env flags are provided:
        # - Enable compressed FA-2 by default (robustly capability-gated at call sites)
        # - Keep sliding FA-2 off by default due to API semantics
        # - Do not use the global "all" default to avoid inadvertently enabling sliding
        if not (fa2_all_set or fa2_win_set or fa2_cmp_set):
            fa2_all_eff = False
            fa2_win_eff = False
            fa2_cmp_eff = True
        else:
            # If NSA_USE_FA2 not set, fall back to model default; else honor explicit value
            fa2_all_eff = self.use_flash_default if not fa2_all_set else fa2_all_env

            # If global is explicitly set to 0, that hard-disables branch flags too
            if fa2_all_set and not fa2_all_env:
                fa2_win_eff = False
                fa2_cmp_eff = False
            else:
                # Branch-specific flags only take effect if explicitly set; otherwise default off
                fa2_win_eff = fa2_win_env if fa2_win_set else False
                fa2_cmp_eff = fa2_cmp_env if fa2_cmp_set else False

        self._env_cache.update(
            {
                "fa2_all_eff": fa2_all_eff,
                "fa2_win_eff": fa2_win_eff,
                "fa2_cmp_eff": fa2_cmp_eff,
            }
        )
        # Parse numeric values
        try:
            self._rope_scale = float(os.getenv("NSA_ROPE_SCALE", "1.0"))
            if not (self._rope_scale > 0.0) or self._rope_scale != self._rope_scale:
                self._rope_scale = 1.0
        except (ValueError, TypeError):
            self._rope_scale = 1.0

        try:
            self._prefill_tile = int(os.getenv("NSA_PREFILL_TILE", "0"))
            if self._prefill_tile < 0:
                self._prefill_tile = 0
        except (ValueError, TypeError):
            self._prefill_tile = 0
        # Fused gate combine (lazy-compiled)
        self._gate_fused_bsg = None
        self._gate_fused_bg = None
```
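The effective-flag logic above can be summarized with a small sketch (probing the private `_env_cache` for illustration only): with no flags set, compressed FA-2 defaults on and sliding stays off, while `NSA_USE_FA2=0` hard-disables both branches.

```python
import os
from nsa.core.nsa_attention import NSAAttention

for env in ({}, {"NSA_USE_FA2": "0"}, {"NSA_USE_FA2_WIN": "1"}):
    for k in ("NSA_USE_FA2", "NSA_USE_FA2_WIN", "NSA_USE_FA2_CMP"):
        os.environ.pop(k, None)
    os.environ.update(env)
    m = NSAAttention(dim=64, n_heads=4, n_kv_groups=2, d_k=16, d_v=16)
    print(env, m._env_cache["fa2_cmp_eff"], m._env_cache["fa2_win_eff"])
# {}                        True  False
# {'NSA_USE_FA2': '0'}      False False
# {'NSA_USE_FA2_WIN': '1'}  False True
```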
```python
    def _shape_q(self, Q: torch.Tensor, B: int, S: int) -> torch.Tensor:
        Q = Q.view(B, S, self.n_heads, self.d_k)
        # group-major: [B,S,G,h,Dk]
        G = self.n_kv_groups
        h = self.h_per_group
        return Q.view(B, S, G, h, self.d_k)

    def _shape_kv(self, X: torch.Tensor, B: int, S: int) -> torch.Tensor:
        G = self.n_kv_groups
        return X.view(B, S, G, -1).permute(0, 2, 1, 3).contiguous()  # [B,G,S,D*]

    def get_gate_stats(self) -> Optional[dict]:
        """Get the most recent gate statistics for monitoring.

        Returns:
            Dict with gate health metrics or None if no recent computation
        """
        return self._last_gate_stats

    def get_fallback_counters(self) -> dict:
        """Get fallback counters for routing monitoring.

        Returns:
            Dict with fallback counts per implementation type
        """
        return self._fallback_counters.copy()

    def get_selection_stats(self) -> Optional[dict]:
        """Return last computed selection length statistics, if available.

        Keys:
        - k_mean: mean selected K per row (float)
        - k_max: max selected K in batch (int)
        - rows: number of (B,S,G) rows aggregated (int)
        - pct_at_max: fraction of rows equal to k_max (float)
        - l_sel: configured selection block size (int)
        - n_sel: configured top-n selection blocks (int)
        """
        return self._last_sel_stats

    def reset_fallback_counters(self) -> dict:
        """Reset fallback counters and return the previous values.

        Returns:
            Dict with fallback counts before reset
        """
        prev_counters = self._fallback_counters.copy()
        for key in self._fallback_counters:
            self._fallback_counters[key] = 0
        return prev_counters

    def _update_gate_stats(self, gates: torch.Tensor) -> None:
        """Update stored gate statistics for monitoring."""
        try:
            self._last_gate_stats = _compute_gate_stats(gates)
        except Exception as e:
            log("warn.gate_stats_fail", error=str(e))
            self._last_gate_stats = None

    def _update_sel_stats_from_ranges(self, ranges: torch.Tensor) -> None:
        """Compute and store selection statistics from [B,*,G,n,2] ranges tensor."""
        try:
            if ranges is None or ranges.numel() == 0:
                self._last_sel_stats = {
                    "k_mean": 0.0,
                    "k_max": 0,
                    "rows": 0,
                    "pct_at_max": 0.0,
                    "l_sel": int(self.l_sel),
                    "n_sel": int(self.n_sel),
                }
                return
            # ranges: [B, T, G, n, 2] or [B, G, n, 2]
            if ranges.dim() == 5:
                B, T, G, n, _ = ranges.shape
                rs = ranges
                rows = B * T * G
                # [B,T,G,n]
                lengths = (rs[..., 1] - rs[..., 0]).clamp_min(0)
                # Sum across n ranges → [B,T,G]
                L = lengths.sum(dim=-1).to(torch.int64)
            elif ranges.dim() == 4:
                B, G, n, _ = ranges.shape
                rs = ranges
                rows = B * G
                lengths = (rs[..., 1] - rs[..., 0]).clamp_min(0)
                L = lengths.sum(dim=-1).to(torch.int64)  # [B,G]
            else:
                # Unknown shape; skip
                return
            if L.numel() == 0:
                k_mean = 0.0
                k_max = 0
                pct_at_max = 0.0
            else:
                k_max = int(L.max().item())
                k_mean = float(L.to(torch.float32).mean().item())
                if k_max > 0:
                    pct_at_max = float((L == k_max).to(torch.float32).mean().item())
                else:
                    pct_at_max = 0.0
            self._last_sel_stats = {
                "k_mean": k_mean,
                "k_max": k_max,
                "rows": int(rows),
                "pct_at_max": pct_at_max,
                "l_sel": int(self.l_sel),
                "n_sel": int(self.n_sel),
            }
        except Exception as e:
            log("warn.sel_stats_fail", error=str(e))
            self._last_sel_stats = None
```
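A minimal sketch of the length statistics this method derives from a `[B,T,G,n,2]` ranges tensor — the same clamp-and-sum the code performs:

```python
import torch

ranges = torch.tensor([[[[[0, 64], [128, 192]]]]])  # [1,1,1,2,2]: two 64-token ranges
lengths = (ranges[..., 1] - ranges[..., 0]).clamp_min(0).sum(dim=-1)
print(lengths)  # tensor([[[128]]]) -> k_mean=128.0, k_max=128, pct_at_max=1.0
```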
```python
    def forward(self, x: torch.Tensor, kv: NSA_KV, *, prefill: bool) -> tuple[torch.Tensor, NSA_KV]:
        """
        Forward pass.

        Args:
            x: [B,S,dim] if prefill else [B,1,dim]
            kv: NSA_KV caches (updated in-place per branch)
            prefill: True for batched prefill, False for single-token decode

        Returns:
            (out, kv): out is [B,S,dim] (prefill) or [B,1,dim] (decode)
        """
        # x: [B,S,dim] (prefill) or [B,1,dim] (decode)
        B, S, _ = x.shape
        assert x.dim() == 3, "x must be [B,S,dim]"
        assert self.n_heads % self.n_kv_groups == 0, "n_heads must be divisible by n_kv_groups"
        # Strict assertions may introduce GPU syncs; gate via env for tests/smokes
        strict_asserts = self._env_cache.get("strict_asserts", False)

        # M8: Assert causal masking - enforce mode constraints
        if prefill:
            assert S > 0, f"Prefill mode requires S > 0, got S={S}"
        else:
            assert S == 1, (
                f"Decode mode requires S=1 (single token), got S={S}. "
                f"This ensures proper causal ordering in decode steps."
            )
        if prefill:
            # Optional: route prefill via single-token decode steps to support very long contexts safely.
            if getattr(self, "prefill_tile", 0) and self.prefill_tile > 0:
                return self._forward_prefill_via_decode(x, kv)
            use_batched = self._env_cache.get("prefill_batched", False)
            if use_batched:
                return self._forward_prefill_batched(x, kv)
            else:
                return self._forward_prefill_sequential(x, kv)
        else:
            # Projections
            # Compute absolute position offset from existing cache length for RoPE on Q
            t_prev = kv.K_sel.shape[2] if hasattr(kv, "K_sel") else 0
            Q_lin = self._shape_q(self.W_Q(x), B, S)  # [B,S,G,h,Dk]
            # Apply RoPE to Q with absolute positions (decode)
            pos = torch.arange(t_prev, t_prev + S, device=x.device)
            Q = apply_rope(
                Q_lin.view(B, S, self.n_heads, self.d_k).reshape(B, S, self.n_heads * self.d_k),
                pos,
                scale=getattr(self, "rope_scale", 1.0),
            )
            Q = Q.view(B, S, self.n_heads, self.d_k)
            G = self.n_kv_groups
            h = self.h_per_group
            Q = Q.view(B, S, G, h, self.d_k)
            K_sel = self._shape_kv(self.W_K_sel(x), B, S)
            V_sel = self._shape_kv(self.W_V_sel(x), B, S)
            K_win = self._shape_kv(self.W_K_win(x), B, S)
            V_win = self._shape_kv(self.W_V_win(x), B, S)
            K_cmp_raw = self._shape_kv(self.W_K_cmp(x), B, S)
            V_cmp_raw = self._shape_kv(self.W_V_cmp(x), B, S)

            # Apply RoPE to K for selection/sliding branches using absolute position of the new token(s)
            # Determine current token index before appending to caches
            t_prev = kv.K_sel.shape[2] if hasattr(kv, "K_sel") else 0
            pos_k = torch.arange(t_prev, t_prev + S, device=x.device)
            K_sel = apply_rope(K_sel, pos_k, scale=getattr(self, "rope_scale", 1.0))
            K_win = apply_rope(K_win, pos_k, scale=getattr(self, "rope_scale", 1.0))

            # decode step: append raw tokens and window, emit compressed every d after warmup l
            kv.update_selection_raw(K_sel, V_sel)
            kv.update_window(K_win, V_win, self.w)
            if not hasattr(kv, "K_cmp_raw_seq"):
                kv.K_cmp_raw_seq = K_cmp_raw[:, :, :0]
                kv.V_cmp_raw_seq = V_cmp_raw[:, :, :0]
                kv.reads_pred = torch.zeros((0,), dtype=torch.int64, device=x.device)
                kv.reads_act_total = torch.zeros((0,), dtype=torch.int64, device=x.device)
                kv.reads_act_sel = torch.zeros((0,), dtype=torch.int64, device=x.device)
                kv.reads_act_cmp = torch.zeros((0,), dtype=torch.int64, device=x.device)
                kv.reads_act_win = torch.zeros((0,), dtype=torch.int64, device=x.device)
            kv.append_cmp_raw(K_cmp_raw, V_cmp_raw)
            S_raw = kv.K_cmp_raw_seq.shape[2]
            if S_raw >= self.l and (S_raw - self.l) % self.d == 0:
                # Emit compressed token from the last l raw tokens
                K_last = kv.K_cmp_raw_seq[:, :, S_raw - self.l : S_raw, :]
                V_last = kv.V_cmp_raw_seq[:, :, S_raw - self.l : S_raw, :]
                pos_last = torch.arange(S_raw - self.l, S_raw, device=x.device)
                if self.phi_type == "mlp":
                    K_cmp_new, V_cmp_new = self._phi_apply_last(K_last, V_last, pos_last)
                else:
                    K_cmp_new, V_cmp_new = avg_pool_phi_rope_kv(
                        K_last, V_last, self.l, self.d, pos=pos_last
                    )
                kv.update_compressed(
                    torch.cat([kv.K_cmp, K_cmp_new], dim=2) if kv.K_cmp.numel() else K_cmp_new,
                    torch.cat([kv.V_cmp, V_cmp_new], dim=2) if kv.V_cmp.numel() else V_cmp_new,
                    self.l,
                    self.d,
                )

            # Ensure block metadata exists and covers current token index for selection (expand if needed)
            t_token = kv.K_sel.shape[2] - 1
            if not hasattr(kv, "meta") or kv.meta.sel_starts.numel() == 0:
                kv.meta = build_block_meta(
                    seq_len=max(t_token + 1, self.l_sel),
                    l=self.l,
                    d=self.d,
                    l_sel=self.l_sel,
                    n_sel=self.n_sel,
                    w=self.w,
                )
            else:
                # If current t exceeds covered selection range, rebuild meta with expanded seq_len
                sel_max_end = (
                    int(kv.meta.sel_starts[-1].item()) + kv.meta.l_sel
                    if kv.meta.sel_starts.numel() > 0
                    else 0
                )
                if (t_token + 1) > sel_max_end:
                    kv.meta = build_block_meta(
                        seq_len=t_token + 1,
                        l=self.l,
                        d=self.d,
                        l_sel=self.l_sel,
                        n_sel=self.n_sel,
                        w=self.w,
                    )
            # Append predicted reads per formula for this step
            num_cmp = 0 if S_raw < self.l else (S_raw - self.l) // self.d + 1
            reads = num_cmp + self.n_sel * self.l_sel + min(self.w, S_raw)
            kv.append_reads_pred(reads)
            # Append actual reads equal to formula in M0
            kv.append_reads_actual(reads, self.n_sel * self.l_sel, num_cmp, min(self.w, S_raw))
            log(
                "decode.reads",
                S_raw=int(S_raw),
                num_cmp=int(num_cmp),
                sel=int(self.n_sel * self.l_sel),
                win=int(min(self.w, S_raw)),
                total=int(reads),
            )

            scale = 1.0 / (self.d_k**0.5)
            # Compute p_cmp only for this step (S is 1 in decode)
            K_cmp_full = kv.K_cmp
            p_cmp_all = compute_pcmp_all(Q, K_cmp_full, scale)
            # Per-token outputs (S should be 1 in decode)
            outs = []
            # Use cached environment variables
            env = self._env_cache

            for t in range(S):
                p_slc_all = map_pcmp_to_pslc_batched(p_cmp_all[:, t : t + 1], kv.meta)

                # M8: Optional Eq.9 verification in decode
                if self._env_cache.get("verify_eq9", False):
                    is_equiv, details = verify_mapping_equivalence(p_cmp_all[:, t : t + 1], kv.meta)
                    if not is_equiv:
                        log(
                            "error.eq9_verification_failed_decode",
                            msg="Eq.9 mapping verification failed in decode",
                            step=t,
                            **details,
                        )
                p_grp = p_slc_all.sum(dim=3).squeeze(1)  # [B,G,S_sel]
                current_pos = kv.K_sel.shape[2] - 1  # Current token position (0-indexed)
                sel_ranges = select_topn_ranges(p_grp, kv.meta, self.n_sel, current_pos, True, 2)

                # M8: Assert causal masking - selection ranges cannot include future tokens
                if strict_asserts and sel_ranges.numel() > 0:
                    # Only sync for strict asserts (debug mode)
                    max_end = sel_ranges[..., 1].max().item()  # GPU sync only in debug
                    assert max_end <= current_pos + 1, (
                        f"Selection range violates causality: max_end={max_end} > current_pos+1={current_pos + 1}. "
                        f"Selection must not access future tokens."
                    )
                # Update selection stats and observability: distance summary per step
                try:
                    # Update per-step selection stats (decode has S==1)
                    self._update_sel_stats_from_ranges(sel_ranges)
                    starts = sel_ranges[..., 0].to(torch.int64)
                    ends = sel_ranges[..., 1].to(torch.int64)
                    lengths = (ends - starts).clamp_min(0)
                    dist = (kv.K_sel.shape[2] - 1) - starts
                    log(
                        "decode.select",
                        n_ranges=int(sel_ranges.shape[2]),
                        mean_len=float(lengths.float().mean().item()) if lengths.numel() else 0.0,
                        max_len=int(lengths.max().item()) if lengths.numel() else 0,
                        mean_dist=float(dist.float().mean().item()) if dist.numel() else 0.0,
                        max_dist=int(dist.max().item()) if dist.numel() else 0,
                    )
                except Exception as e:
                    log("warn.decode.select_log_fail", error=str(e))
                Q_t = Q[:, t]
                K_sel_t = kv.K_sel
                V_sel_t = kv.V_sel
                # Selection attention: prefer Triton if enabled; else packed; fallback to gather
                force_parity = env["force_parity"]
                use_sel_pack = env["use_sel_pack"] and not force_parity
                use_triton_sel = env["use_triton_sel"] and not force_parity
                use_cuda_sel = env["use_cuda_sel"] and not force_parity
                force_sel_mask = env.get("force_sel_mask", False) and not force_parity
                if force_sel_mask:
                    try:
                        O_sel_bt = grouped_selection_attention_masked(
                            Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                        )
                        O_sel = O_sel_bt[:, 0]
                        log("decode.sel.path", path="masked_forced")
                    except Exception as e:
                        self._fallback_counters["selection_mask_fails"] += 1
                        self._fallback_counters["total_fallbacks"] += 1
                        log(
                            "warn.masked_selection_forced_fallback",
                            error=str(e),
                            step=t,
                            Q_shape=list(Q_t.shape),
                            K_shape=list(K_sel_t.shape),
                            V_shape=list(V_sel_t.shape),
                            ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
                            total_fails=self._fallback_counters["selection_mask_fails"],
                        )
                        O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
                elif use_triton_sel:
                    try:
                        from nsa.kernels.triton_sel_kernel import selection_attention_triton

                        O_sel_bt = selection_attention_triton(
                            Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                        )
                        O_sel = O_sel_bt[:, 0]
                        log("decode.sel.path", path="triton")
```
except Exception as e:
|
| 733 |
+
# M8: Fallback counter - Triton selection failed
|
| 734 |
+
self._fallback_counters["selection_triton_fails"] += 1
|
| 735 |
+
self._fallback_counters["total_fallbacks"] += 1
|
| 736 |
+
log(
|
| 737 |
+
"warn.triton_selection_fallback",
|
| 738 |
+
error=str(e),
|
| 739 |
+
step=t,
|
| 740 |
+
Q_shape=list(Q_t.shape),
|
| 741 |
+
K_shape=list(K_sel_t.shape),
|
| 742 |
+
V_shape=list(V_sel_t.shape),
|
| 743 |
+
ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
|
| 744 |
+
total_fails=self._fallback_counters["selection_triton_fails"],
|
| 745 |
+
)
|
| 746 |
+
# Fallback to packed SDPA
|
| 747 |
+
O_sel_bt = grouped_selection_attention_packed(
|
| 748 |
+
Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
|
| 749 |
+
)
|
| 750 |
+
O_sel = O_sel_bt[:, 0]
|
| 751 |
+
elif use_cuda_sel:
|
| 752 |
+
try:
|
| 753 |
+
from nsa.kernels.cuda_sel_kernel import selection_attention_cuda
|
| 754 |
+
|
| 755 |
+
O_sel_bt = selection_attention_cuda(
|
| 756 |
+
Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
|
| 757 |
+
)
|
| 758 |
+
O_sel = O_sel_bt[:, 0]
|
| 759 |
+
except Exception as e:
|
| 760 |
+
# M8: Fallback counter - CUDA selection failed
|
| 761 |
+
self._fallback_counters["selection_cuda_fails"] += 1
|
| 762 |
+
self._fallback_counters["total_fallbacks"] += 1
|
| 763 |
+
log(
|
| 764 |
+
"warn.cuda_selection_fallback",
|
| 765 |
+
error=str(e),
|
| 766 |
+
step=t,
|
| 767 |
+
Q_shape=list(Q_t.shape),
|
| 768 |
+
K_shape=list(K_sel_t.shape),
|
| 769 |
+
V_shape=list(V_sel_t.shape),
|
| 770 |
+
ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
|
| 771 |
+
total_fails=self._fallback_counters["selection_cuda_fails"],
|
| 772 |
+
)
|
| 773 |
+
# Fallback to packed SDPA
|
| 774 |
+
O_sel_bt = grouped_selection_attention_packed(
|
| 775 |
+
Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
|
| 776 |
+
)
|
| 777 |
+
O_sel = O_sel_bt[:, 0]
|
| 778 |
+
elif use_sel_pack:
|
| 779 |
+
try:
|
| 780 |
+
O_sel_bt = grouped_selection_attention_packed(
|
| 781 |
+
Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
|
| 782 |
+
)
|
| 783 |
+
O_sel = O_sel_bt[:, 0]
|
| 784 |
+
log("decode.sel.path", path="packed")
|
| 785 |
+
except Exception as e:
|
| 786 |
+
# M8: Fallback counter - Packed selection failed
|
| 787 |
+
self._fallback_counters["selection_pack_fails"] += 1
|
| 788 |
+
self._fallback_counters["total_fallbacks"] += 1
|
| 789 |
+
log(
|
| 790 |
+
"warn.packed_selection_fallback",
|
| 791 |
+
error=str(e),
|
| 792 |
+
step=t,
|
| 793 |
+
Q_shape=list(Q_t.shape),
|
| 794 |
+
K_shape=list(K_sel_t.shape),
|
| 795 |
+
V_shape=list(V_sel_t.shape),
|
| 796 |
+
ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
|
| 797 |
+
total_fails=self._fallback_counters["selection_pack_fails"],
|
| 798 |
+
)
|
| 799 |
+
# Fallback to gather SDPA
|
| 800 |
+
O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
|
| 801 |
+
elif self._env_cache.get("use_sel_mask", False) and not force_parity:
|
| 802 |
+
try:
|
| 803 |
+
O_sel_bt = grouped_selection_attention_masked(
|
| 804 |
+
Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
|
| 805 |
+
)
|
| 806 |
+
O_sel = O_sel_bt[:, 0]
|
| 807 |
+
log("decode.sel.path", path="masked")
|
| 808 |
+
except Exception as e:
|
| 809 |
+
# M8: Fallback counter - Masked selection failed
|
| 810 |
+
self._fallback_counters["selection_mask_fails"] += 1
|
| 811 |
+
self._fallback_counters["total_fallbacks"] += 1
|
| 812 |
+
log(
|
| 813 |
+
"warn.masked_selection_fallback",
|
| 814 |
+
error=str(e),
|
| 815 |
+
step=t,
|
| 816 |
+
Q_shape=list(Q_t.shape),
|
| 817 |
+
K_shape=list(K_sel_t.shape),
|
| 818 |
+
V_shape=list(V_sel_t.shape),
|
| 819 |
+
ranges_shape=list(sel_ranges.shape) if sel_ranges is not None else None,
|
| 820 |
+
total_fails=self._fallback_counters["selection_mask_fails"],
|
| 821 |
+
)
|
| 822 |
+
# Fallback to gather SDPA
|
| 823 |
+
O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
|
| 824 |
+
else:
|
| 825 |
+
O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
|
| 826 |
+
win_len = min(self.w, kv.K_win.shape[2])
|
| 827 |
+
|
| 828 |
+
# M8: Assert causal masking - sliding window bounds in decode
|
| 829 |
+
total_tokens = kv.K_win.shape[2]
|
| 830 |
+
start_idx = total_tokens - win_len
|
| 831 |
+
end_idx = total_tokens
|
| 832 |
+
assert start_idx >= 0, (
|
| 833 |
+
f"Sliding window start index negative: start_idx={start_idx}, "
|
| 834 |
+
f"total_tokens={total_tokens}, win_len={win_len}"
|
| 835 |
+
)
|
| 836 |
+
assert end_idx <= total_tokens, (
|
| 837 |
+
f"Sliding window end exceeds cache: end_idx={end_idx} > total_tokens={total_tokens}"
|
| 838 |
+
)
|
| 839 |
+
assert win_len <= self.w, (
|
| 840 |
+
f"Window length exceeds max: win_len={win_len} > self.w={self.w}"
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
K_w = kv.K_win[:, :, start_idx:end_idx, :]
|
| 844 |
+
V_w = kv.V_win[:, :, start_idx:end_idx, :]
|
| 845 |
+
use_flash = (
|
| 846 |
+
env["fa2_all_eff"] or env["fa2_win_eff"] or env["fa2_cmp_eff"]
|
| 847 |
+
) and not force_parity
|
| 848 |
+
if use_flash and (env["fa2_all_eff"] or env["fa2_win_eff"]):
|
| 849 |
+
try:
|
| 850 |
+
O_win = sliding_window_attention_fa2_decode(Q_t, kv.K_win, kv.V_win, self.w)
|
| 851 |
+
except Exception as e:
|
| 852 |
+
# M8: Fallback counter - Sliding FA2 failed
|
| 853 |
+
self._fallback_counters["sliding_fa2_fails"] += 1
|
| 854 |
+
self._fallback_counters["total_fallbacks"] += 1
|
| 855 |
+
log(
|
| 856 |
+
"warn.sliding_fa2_fallback",
|
| 857 |
+
error=str(e),
|
| 858 |
+
total_fails=self._fallback_counters["sliding_fa2_fails"],
|
| 859 |
+
)
|
| 860 |
+
# Fallback to standard attention
|
| 861 |
+
O_win = attention_bgh(
|
| 862 |
+
Q_t.contiguous(), K_w.contiguous(), V_w.contiguous(), causal=True
|
| 863 |
+
)
|
| 864 |
+
else:
|
| 865 |
+
O_win = attention_bgh(
|
| 866 |
+
Q_t.contiguous(), K_w.contiguous(), V_w.contiguous(), causal=True
|
| 867 |
+
)
|
| 868 |
+
S_cmp_t = kv.K_cmp.shape[2]
|
| 869 |
+
|
| 870 |
+
# M8: Assert causal masking - compressed bounds in decode
|
| 871 |
+
assert S_cmp_t >= 0, f"Compressed cache size negative: S_cmp_t={S_cmp_t}"
|
| 872 |
+
assert S_cmp_t <= kv.K_cmp.shape[2], (
|
| 873 |
+
f"Compressed range exceeds cache: S_cmp_t={S_cmp_t} > cache_size={kv.K_cmp.shape[2]}"
|
| 874 |
+
)
|
| 875 |
+
|
| 876 |
+
if use_flash and (env["fa2_all_eff"] or env["fa2_cmp_eff"]):
|
| 877 |
+
try:
|
| 878 |
+
O_cmp = compressed_attention_fa2_decode(Q_t, kv.K_cmp, kv.V_cmp, S_cmp_t)
|
| 879 |
+
except Exception as e:
|
| 880 |
+
# M8: Fallback counter - Compressed FA2 failed
|
| 881 |
+
self._fallback_counters["compressed_fa2_fails"] += 1
|
| 882 |
+
self._fallback_counters["total_fallbacks"] += 1
|
| 883 |
+
log(
|
| 884 |
+
"warn.compressed_fa2_fallback",
|
| 885 |
+
error=str(e),
|
| 886 |
+
total_fails=self._fallback_counters["compressed_fa2_fails"],
|
| 887 |
+
)
|
| 888 |
+
# Fallback to standard attention
|
| 889 |
+
O_cmp = attention_bgh(
|
| 890 |
+
Q_t.contiguous(),
|
| 891 |
+
kv.K_cmp[:, :, :S_cmp_t, :].contiguous(),
|
| 892 |
+
kv.V_cmp[:, :, :S_cmp_t, :].contiguous(),
|
| 893 |
+
causal=True,
|
| 894 |
+
)
|
| 895 |
+
else:
|
| 896 |
+
O_cmp = attention_bgh(
|
| 897 |
+
Q_t.contiguous(),
|
| 898 |
+
kv.K_cmp[:, :, :S_cmp_t, :].contiguous(),
|
| 899 |
+
kv.V_cmp[:, :, :S_cmp_t, :].contiguous(),
|
| 900 |
+
causal=True,
|
| 901 |
+
)
|
| 902 |
+
# Preserve dtype for gate input
|
| 903 |
+
q_gp = Q_t.mean(dim=2, dtype=Q_t.dtype)
|
| 904 |
+
if self._env_cache.get("gate_compile", False):
|
| 905 |
+
try:
|
| 906 |
+
fused = self._gate_fused_bg
|
| 907 |
+
if fused is None:
|
| 908 |
+
fused = _fused_gate_combine_bg
|
| 909 |
+
try:
|
| 910 |
+
fused = torch.compile(fused, mode="reduce-overhead") # type: ignore[attr-defined]
|
| 911 |
+
except Exception:
|
| 912 |
+
pass
|
| 913 |
+
self._gate_fused_bg = fused
|
| 914 |
+
O = fused(
|
| 915 |
+
q_gp,
|
| 916 |
+
O_cmp,
|
| 917 |
+
O_sel,
|
| 918 |
+
O_win,
|
| 919 |
+
self.gate.fc1.weight,
|
| 920 |
+
self.gate.fc1.bias,
|
| 921 |
+
self.gate.fc2.weight,
|
| 922 |
+
self.gate.fc2.bias,
|
| 923 |
+
float(self.gate_temp),
|
| 924 |
+
)
|
| 925 |
+
except Exception:
|
| 926 |
+
gates = self.gate(q_gp, tau=self.gate_temp)
|
| 927 |
+
if self._env_cache.get("stopgrad_gates", False):
|
| 928 |
+
gates = gates.detach()
|
| 929 |
+
self._update_gate_stats(gates)
|
| 930 |
+
try:
|
| 931 |
+
log(
|
| 932 |
+
"decode.gates",
|
| 933 |
+
mean=gates.mean(dim=(-1, -2)).tolist()
|
| 934 |
+
if gates.dim() >= 2
|
| 935 |
+
else gates.mean().item(),
|
| 936 |
+
std=gates.std(dim=(-1, -2)).tolist()
|
| 937 |
+
if gates.dim() >= 2
|
| 938 |
+
else gates.std().item(),
|
| 939 |
+
)
|
| 940 |
+
except Exception as e:
|
| 941 |
+
log("warn.decode.gate_log_fail", error=str(e))
|
| 942 |
+
w_cmp = gates[..., 0:1].unsqueeze(-1)
|
| 943 |
+
w_sel = gates[..., 1:2].unsqueeze(-1)
|
| 944 |
+
w_win = gates[..., 2:3].unsqueeze(-1)
|
| 945 |
+
O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win
|
| 946 |
+
else:
|
| 947 |
+
gates = self.gate(q_gp, tau=self.gate_temp)
|
| 948 |
+
if self._env_cache.get("stopgrad_gates", False):
|
| 949 |
+
gates = gates.detach()
|
| 950 |
+
self._update_gate_stats(gates)
|
| 951 |
+
try:
|
| 952 |
+
log(
|
| 953 |
+
"decode.gates",
|
| 954 |
+
mean=gates.mean(dim=(-1, -2)).tolist()
|
| 955 |
+
if gates.dim() >= 2
|
| 956 |
+
else gates.mean().item(),
|
| 957 |
+
std=gates.std(dim=(-1, -2)).tolist()
|
| 958 |
+
if gates.dim() >= 2
|
| 959 |
+
else gates.std().item(),
|
| 960 |
+
)
|
| 961 |
+
except Exception as e:
|
| 962 |
+
log("warn.decode.gate_log_fail", error=str(e))
|
| 963 |
+
w_cmp = gates[..., 0:1].unsqueeze(-1)
|
| 964 |
+
w_sel = gates[..., 1:2].unsqueeze(-1)
|
| 965 |
+
w_win = gates[..., 2:3].unsqueeze(-1)
|
| 966 |
+
O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win
|
| 967 |
+
O_heads = O.reshape(B, self.n_heads, self.d_v)
|
| 968 |
+
out_t = self.out(O_heads.reshape(B, 1, -1))
|
| 969 |
+
outs.append(out_t)
|
| 970 |
+
out = torch.cat(outs, dim=1)
|
| 971 |
+
return out, kv
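
    # Worked example for the read accounting above (illustrative only; not part
    # of the forward path). Per decode step:
    #     reads(S_raw) = num_cmp + n_sel * l_sel + min(w, S_raw)
    #     num_cmp      = 0 if S_raw < l else (S_raw - l) // d + 1
    # With this checkpoint's l=32, d=16, l_sel=64, n_sel=16, w=512 and
    # S_raw = 2048: num_cmp = (2048 - 32) // 16 + 1 = 127, so
    # reads = 127 + 16 * 64 + 512 = 1663, versus 2048 KV reads for full
    # attention at the same position; the gap widens linearly with context.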

    def _forward_prefill_batched(self, x: torch.Tensor, kv: NSA_KV) -> tuple[torch.Tensor, NSA_KV]:
        """
        Vectorized prefill path.

        Steps:
        - Projections with RoPE(Q); RoPE applied to K before ϕ for compressed branch
        - Cache updates for selection/window/compressed
        - Batched p_cmp → p_slc → p_grp; top‑n ranges for all t
        - Branch attentions (masked/packed per env flags), gating, projection
        """
        B, S, _ = x.shape
        # Projections
        _nvtx = self._env_cache.get("nvtx", False)
        if _nvtx:
            try:
                import torch as _t

                _t.cuda.nvtx.range_push("projections+rope")
            except Exception:
                _nvtx = False
        Q_lin = self._shape_q(self.W_Q(x), B, S)  # [B,S,G,h,Dk]
        assert Q_lin.shape[:2] == (B, S)
        # Apply RoPE to Q
        pos = torch.arange(S, device=x.device)
        Q = apply_rope(
            Q_lin.view(B, S, self.n_heads, self.d_k).reshape(B, S, self.n_heads * self.d_k),
            pos,
            scale=getattr(self, "rope_scale", 1.0),
        )
        Q = Q.view(B, S, self.n_heads, self.d_k).view(
            B, S, self.n_kv_groups, self.h_per_group, self.d_k
        )
        # K/V projections per branch
        K_sel = self._shape_kv(self.W_K_sel(x), B, S)
        V_sel = self._shape_kv(self.W_V_sel(x), B, S)
        K_win = self._shape_kv(self.W_K_win(x), B, S)
        V_win = self._shape_kv(self.W_V_win(x), B, S)
        K_cmp_raw = self._shape_kv(self.W_K_cmp(x), B, S)
        V_cmp_raw = self._shape_kv(self.W_V_cmp(x), B, S)
        G = self.n_kv_groups
        assert K_sel.shape[:3] == (B, G, S) and V_sel.shape[:3] == (B, G, S)
        assert K_win.shape[:3] == (B, G, S) and V_win.shape[:3] == (B, G, S)
        assert K_cmp_raw.shape[:3] == (B, G, S) and V_cmp_raw.shape[:3] == (B, G, S)

        # Apply RoPE to per-branch K tensors (Q already has RoPE applied)
        pos_k = torch.arange(S, device=x.device)
        K_sel = apply_rope(K_sel, pos_k, scale=getattr(self, "rope_scale", 1.0))
        K_win = apply_rope(K_win, pos_k, scale=getattr(self, "rope_scale", 1.0))
        if _nvtx:
            try:
                _t.cuda.nvtx.range_pop()
            except Exception:
                pass

        # Update caches (prefill uses full sequence projections)
        kv.update_selection_raw(K_sel, V_sel)
        # Build/refresh meta for selection and compressed mapping
        kv.meta = build_block_meta(
            seq_len=S, l=self.l, d=self.d, l_sel=self.l_sel, n_sel=self.n_sel, w=self.w
        )
        kv.update_window(K_win, V_win, self.w)
        if self.phi_type == "mlp":
            K_cmp, V_cmp = self._phi_apply_seq(
                K_cmp_raw, V_cmp_raw, pos=torch.arange(S, device=x.device)
            )
        else:
            K_cmp, V_cmp = avg_pool_phi_rope_kv(
                K_cmp_raw, V_cmp_raw, self.l, self.d, pos=torch.arange(S, device=x.device)
            )
        kv.update_compressed(K_cmp, V_cmp, self.l, self.d)

        # One-time SDPA backend audit (opt-in via env)
        try:
            if (not self._sdpa_audited) and os.getenv("NSA_SDPA_AUDIT", "0").lower() in (
                "1",
                "true",
                "yes",
            ):
                self._audit_sdpa_backends_once(
                    Q[:, :1],
                    K_sel[:, :, : max(1, S // 8), :],
                    V_sel[:, :, : max(1, S // 8), :],
                    K_win[:, :, : max(1, S // 8), :],
                    V_win[:, :, : max(1, S // 8), :],
                )
        except Exception:
            pass

        # Selection scores (batched)
        scale = 1.0 / (self.d_k**0.5)
        if _nvtx:
            try:
                _t.cuda.nvtx.range_push("pcmp_all")
            except Exception:
                pass
        p_cmp_all = compute_pcmp_all(Q, kv.K_cmp, scale)  # [B,S,G,h,S_cmp]
        if _nvtx:
            try:
                _t.cuda.nvtx.range_pop()
                _t.cuda.nvtx.range_push("map_pcmp_to_pslc")
            except Exception:
                pass
        p_slc_all = map_pcmp_to_pslc_batched(p_cmp_all, kv.meta)  # [B,S,G,h,S_sel]

        # M8: Optional Eq.9 verification in batched prefill
        if self._env_cache.get("verify_eq9", False):
            is_equiv, details = verify_mapping_equivalence(p_cmp_all, kv.meta)
            if not is_equiv:
                log(
                    "error.eq9_verification_failed_prefill",
                    msg="Eq.9 mapping verification failed in batched prefill",
                    **details,
                )
        p_grp_all = p_slc_all.sum(dim=3)  # [B,S,G,S_sel]
        log(
            "prefill.scores",
            B=B,
            S=S,
            S_cmp=int(kv.K_cmp.shape[2]),
            S_sel=int(kv.meta.sel_starts.numel()),
        )

        # Batched top‑n → ranges for all positions
        if _nvtx:
            try:
                _t.cuda.nvtx.range_push("topk+ranges")
            except Exception:
                pass
        sel_ranges_all = select_topn_ranges_batched(
            p_grp_all, kv.meta, self.n_sel, S, True, 2
        )  # [B,S,G,n,2]
        if _nvtx:
            try:
                _t.cuda.nvtx.range_pop()
                _t.cuda.nvtx.range_push("branch_attn+gate")
            except Exception:
                pass
        # Update selection statistics for this prefill batch
        self._update_sel_stats_from_ranges(sel_ranges_all)
        if _nvtx:
            try:
                _t.cuda.nvtx.range_pop()
            except Exception:
                pass

        # M8: Assert causal masking for batched selection (GPU-sync gated)
        strict_asserts = self._env_cache.get("strict_asserts", False)
        if strict_asserts and sel_ranges_all.numel() > 0:
            for t in range(S):
                t_ranges = sel_ranges_all[:, t]  # [B,G,n,2]
                if t_ranges.numel() > 0:
                    max_end = t_ranges[..., 1].max().item()
                    assert max_end <= t + 1, (
                        f"Batched selection violates causality at t={t}: max_end={max_end} > t+1={t + 1}. "
                        f"Selection ranges cannot access future tokens."
                    )
        log("prefill.select", n_sel=self.n_sel, l_sel=self.l_sel, ranges=sel_ranges_all)

        # Branch attentions in parallel (parity-first for cmp/win, with optional masked SDPA gates)
        force_parity = self._env_cache.get("force_parity", False)
        fa2_all = self._env_cache.get("fa2_all_eff", False)
        fa2_win = self._env_cache.get("fa2_win_eff", False)
        fa2_cmp = self._env_cache.get("fa2_cmp_eff", False)
        use_cmp_mask = self._env_cache.get("use_cmp_mask", True) and not force_parity
        if (fa2_all or fa2_cmp) and not force_parity:
            try:
                O_cmp = compressed_attention_fa2(Q, kv.K_cmp, kv.V_cmp, self.l, self.d)
            except Exception as e:
                # M8: Fallback counter - Compressed FA2 failed in prefill
                self._fallback_counters["compressed_fa2_fails"] += 1
                self._fallback_counters["total_fallbacks"] += 1
                log(
                    "warn.compressed_fa2_prefill_fallback",
                    error=str(e),
                    total_fails=self._fallback_counters["compressed_fa2_fails"],
                )
                # Fallback to masked SDPA
                from nsa.core.attention_kernels import batched_causal_attention_compressed_masked

                O_cmp = batched_causal_attention_compressed_masked(
                    Q, kv.K_cmp, kv.V_cmp, self.l, self.d
                )
        elif use_cmp_mask:
            from nsa.core.attention_kernels import batched_causal_attention_compressed_masked

            O_cmp = batched_causal_attention_compressed_masked(
                Q, kv.K_cmp, kv.V_cmp, self.l, self.d
            )
        else:
            # Compressed per-t using the same kernel as sequential
            O_cmp = torch.zeros(
                (B, S, self.n_kv_groups, self.h_per_group, self.d_v),
                device=x.device,
                dtype=V_cmp.dtype,
            )
            S_cmp_full = kv.K_cmp.shape[2]
            for t in range(S):
                L = 0 if (t + 1) < self.l else min(((t + 1 - self.l) // self.d) + 1, S_cmp_full)

                # M8: Assert causal masking - compressed tokens must respect position bounds
                if L > 0:
                    # Check that compressed range doesn't exceed causal bounds
                    assert L <= S_cmp_full, (
                        f"Compressed range exceeds cache: L={L} > S_cmp_full={S_cmp_full} at t={t}"
                    )
                    # Verify causal constraint: at position t, can only see compressed tokens
                    # that represent original positions up to t
                    max_allowed_L = ((t + 1 - self.l) // self.d) + 1 if (t + 1) >= self.l else 0
                    assert L <= max_allowed_L, (
                        f"Compressed range violates causality: L={L} > max_allowed_L={max_allowed_L} "
                        f"at t={t}. Compressed tokens represent future positions."
                    )

                q_t = Q[:, t].contiguous()
                k_t = kv.K_cmp[:, :, :L, :].contiguous()
                v_t = kv.V_cmp[:, :, :L, :].contiguous()
                O_cmp[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
        # Strict finite check and fallback
        if strict_asserts and not torch.isfinite(O_cmp).all():
            from nsa.core.attention_kernels import batched_causal_attention_compressed_masked

            log("warn.prefill_cmp_nonfinite_fallback")
            O_cmp = batched_causal_attention_compressed_masked(
                Q, kv.K_cmp, kv.V_cmp, self.l, self.d
            )
        log("prefill.cmp", O_cmp=O_cmp)

        # Selected ranges attention (prefer Triton if enabled; else packed/gather)
        use_sel_pack = self._env_cache.get("use_sel_pack", True) and not force_parity
        use_sel_varlen = self._env_cache.get("use_sel_varlen", False) and not force_parity
        use_triton_sel = (
            self._env_cache.get("use_triton_sel", False) or self.use_triton_sel
        ) and not force_parity
        force_sel_mask = self._env_cache.get("force_sel_mask", False) and not force_parity
        if force_sel_mask:
            try:
                O_sel = grouped_selection_attention_masked(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
                log("prefill.sel.path", path="masked_forced")
            except Exception as e:
                self._fallback_counters["selection_mask_fails"] += 1
                self._fallback_counters["total_fallbacks"] += 1
                log(
                    "warn.masked_selection_prefill_forced_fallback",
                    error=str(e),
                    Q_shape=list(Q.shape),
                    K_shape=list(kv.K_sel.shape),
                    V_shape=list(kv.V_sel.shape),
                    ranges_shape=list(sel_ranges_all.shape),
                    total_fails=self._fallback_counters["selection_mask_fails"],
                )
                # Fallback to gather SDPA
                O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
        elif use_triton_sel:
            try:
                from nsa.kernels.triton_sel_kernel import selection_attention_triton

                O_sel = selection_attention_triton(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
                log("prefill.sel.path", path="triton")
            except Exception as e:
                # M8: Fallback counter - Triton selection failed in prefill
                self._fallback_counters["selection_triton_fails"] += 1
                self._fallback_counters["total_fallbacks"] += 1
                log(
                    "warn.triton_selection_prefill_fallback",
                    error=str(e),
                    Q_shape=list(Q.shape),
                    K_shape=list(kv.K_sel.shape),
                    V_shape=list(kv.V_sel.shape),
                    ranges_shape=list(sel_ranges_all.shape),
                    total_fails=self._fallback_counters["selection_triton_fails"],
                )
                # Fallback to packed SDPA
                O_sel = grouped_selection_attention_packed(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
        elif use_sel_varlen:
            try:
                from nsa.core.attention_kernels import selection_attention_varlen_all

                O_sel = selection_attention_varlen_all(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
                log("prefill.sel.path", path="varlen")
            except Exception as e:
                # Fallback counter reuse for selection pack failures
                self._fallback_counters["selection_pack_fails"] += 1
                self._fallback_counters["total_fallbacks"] += 1
                log(
                    "warn.selection_varlen_prefill_fallback",
                    error=str(e),
                    total_fails=self._fallback_counters["selection_pack_fails"],
                )
                # Fallback to packed SDPA
                O_sel = grouped_selection_attention_packed(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
                log("prefill.sel.path", path="packed")
        elif use_sel_pack:
            try:
                O_sel = grouped_selection_attention_packed(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
            except Exception as e:
                # M8: Fallback counter - Packed selection failed in prefill
                self._fallback_counters["selection_pack_fails"] += 1
                self._fallback_counters["total_fallbacks"] += 1
                log(
                    "warn.packed_selection_prefill_fallback",
                    error=str(e),
                    total_fails=self._fallback_counters["selection_pack_fails"],
                )
                # Fallback to gather SDPA
                O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
        elif self._env_cache.get("use_sel_mask", False):
            try:
                O_sel = grouped_selection_attention_masked(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
                log("prefill.sel.path", path="masked")
            except Exception as e:
                # M8: Fallback counter - Masked selection failed in prefill
                self._fallback_counters["selection_mask_fails"] += 1
                self._fallback_counters["total_fallbacks"] += 1
                log(
                    "warn.masked_selection_prefill_fallback",
                    error=str(e),
                    total_fails=self._fallback_counters["selection_mask_fails"],
                )
                # Fallback to gather SDPA
                O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
        else:
            O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
            log("prefill.sel.path", path="gather")
        if strict_asserts and not torch.isfinite(O_sel).all():
            log("warn.prefill_sel_nonfinite_fallback")
            O_sel = grouped_selection_attention(Q, kv.K_sel, kv.V_sel, sel_ranges_all)
        log("prefill.sel", O_sel=O_sel)

        use_win_mask = self._env_cache.get("use_win_mask", True) and not force_parity
        if (fa2_all or fa2_win) and not force_parity:
            try:
                O_win = sliding_window_attention_fa2(Q, K_win, V_win, self.w)
            except Exception as e:
                # M8: Fallback counter - Sliding FA2 failed in prefill
                self._fallback_counters["sliding_fa2_fails"] += 1
                self._fallback_counters["total_fallbacks"] += 1
                log(
                    "warn.sliding_fa2_prefill_fallback",
                    error=str(e),
                    total_fails=self._fallback_counters["sliding_fa2_fails"],
                )
                # Fallback to masked SDPA
                from nsa.core.attention_kernels import sliding_window_attention

                O_win = sliding_window_attention(Q, K_win, V_win, self.w)
        elif use_win_mask:
            from nsa.core.attention_kernels import sliding_window_attention

            O_win = sliding_window_attention(Q, K_win, V_win, self.w)
        else:
            # Sliding per-t using the same kernel as sequential
            O_win = torch.zeros(
                (B, S, self.n_kv_groups, self.h_per_group, self.d_v),
                device=x.device,
                dtype=V_win.dtype,
            )
            for t in range(S):
                end = t + 1
                start = max(0, end - self.w)

                # M8: Assert causal masking - sliding window must not exceed current position
                assert end <= t + 1, (
                    f"Sliding window violates causality: end={end} > t+1={t + 1} at position t={t}. "
                    f"This indicates window is accessing future tokens."
                )
                assert start <= end, (
                    f"Sliding window has invalid range: start={start} > end={end} at position t={t}."
                )

                q_t = Q[:, t].contiguous()
                k_t = K_win[:, :, start:end, :].contiguous()
                v_t = V_win[:, :, start:end, :].contiguous()
                O_win[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
        if strict_asserts and not torch.isfinite(O_win).all():
            from nsa.core.attention_kernels import sliding_window_attention

            log("warn.prefill_win_nonfinite_fallback")
            O_win = sliding_window_attention(Q, K_win, V_win, self.w)
        log("prefill.win", O_win=O_win)

        # Gates and combine
        q_gp = Q.mean(dim=3)  # [B,S,G,Dk]
        if self._env_cache.get("gate_compile", False):
            try:
                fused = self._gate_fused_bsg
                if fused is None:
                    fused = _fused_gate_combine_bsg
                    try:
                        fused = torch.compile(fused, mode="reduce-overhead")  # type: ignore[attr-defined]
                    except Exception:
                        pass
                    self._gate_fused_bsg = fused
                O = fused(
                    q_gp,
                    O_cmp,
                    O_sel,
                    O_win,
                    self.gate.fc1.weight,
                    self.gate.fc1.bias,
                    self.gate.fc2.weight,
                    self.gate.fc2.bias,
                    float(self.gate_temp),
                )
            except Exception:
                gates = self.gate(q_gp.reshape(B * S * self.n_kv_groups, self.d_k), tau=self.gate_temp)
                if self._env_cache.get("stopgrad_gates", False):
                    gates = gates.detach()
                gates = gates.view(B, S, self.n_kv_groups, 3)  # [B,S,G,3]
                self._update_gate_stats(gates)
                w_cmp = gates[..., 0:1].unsqueeze(3)
                w_sel = gates[..., 1:2].unsqueeze(3)
                w_win = gates[..., 2:3].unsqueeze(3)
                O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win  # [B,S,G,h,Dv]
        else:
            gates = self.gate(q_gp.reshape(B * S * self.n_kv_groups, self.d_k), tau=self.gate_temp)
            if self._env_cache.get("stopgrad_gates", False):
                gates = gates.detach()
            gates = gates.view(B, S, self.n_kv_groups, 3)  # [B,S,G,3]
            self._update_gate_stats(gates)
            w_cmp = gates[..., 0:1].unsqueeze(3)
            w_sel = gates[..., 1:2].unsqueeze(3)
            w_win = gates[..., 2:3].unsqueeze(3)
            O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win  # [B,S,G,h,Dv]

        # Output projection
        O_heads = O.reshape(B, S, self.n_kv_groups * self.h_per_group, self.d_v)
        out = self.out(O_heads.reshape(B, S, -1))
        log("prefill.out", out=out)

        # Optional debug compare: sequential-style per-token recompute to measure MAE
        if self._env_cache.get("debug_compare", False):
            with torch.no_grad():
                # Compressed per-token recompute
                O_cmp_seq = torch.zeros_like(O_cmp)
                S_cmp = kv.K_cmp.shape[2]
                for t in range(S):
                    L = 0 if (t + 1) < self.l else min(((t + 1 - self.l) // self.d) + 1, S_cmp)

                    # M8: Assert causal masking in debug recompute
                    if L > 0:
                        assert L <= S_cmp, (
                            f"Debug compressed range exceeds cache: L={L} > S_cmp={S_cmp} at t={t}"
                        )

                    q_t = Q[:, t].contiguous()
                    k_t = kv.K_cmp[:, :, :L, :].contiguous()
                    v_t = kv.V_cmp[:, :, :L, :].contiguous()
                    O_cmp_seq[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
                cmp_mae = (O_cmp - O_cmp_seq).abs().mean().item()
                print(f"NSA-DBG cmp_mae={cmp_mae:.6e}")

                # Sliding per-token recompute
                O_win_seq = torch.zeros_like(O_win)
                for t in range(S):
                    end = t + 1
                    start = max(0, end - self.w)
                    q_t = Q[:, t].contiguous()
                    k_t = K_win[:, :, start:end, :].contiguous()
                    v_t = V_win[:, :, start:end, :].contiguous()
                    O_win_seq[:, t] = attention_bgh(q_t, k_t, v_t, causal=True)
                win_mae = (O_win - O_win_seq).abs().mean().item()
                print(f"NSA-DBG win_mae={win_mae:.6e}")

                # Final output recompute using seq per-branch
                w_cmp_dbg = gates[..., 0:1].unsqueeze(-1)
                w_sel_dbg = gates[..., 1:2].unsqueeze(-1)
                w_win_dbg = gates[..., 2:3].unsqueeze(-1)
                O_seq = w_cmp_dbg * O_cmp_seq + w_sel_dbg * O_sel + w_win_dbg * O_win_seq
                O_heads_seq = O_seq.reshape(B, S, self.n_kv_groups * self.h_per_group, self.d_v)
                out_seq = self.out(O_heads_seq.reshape(B, S, -1))
                out_mae = (out - out_seq).abs().mean().item()
                print(f"NSA-DBG out_mae={out_mae:.6e}")
        return out, kv
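
    # Reference for the per-t compressed loop above: a query at position t
    # (0-indexed) may attend to
    #     L(t) = 0                     if t + 1 < l
    #     L(t) = (t + 1 - l) // d + 1  otherwise
    # compressed tokens, i.e. one new compressed token becomes visible every d
    # raw positions once the first l tokens exist (l=32, d=16: t=30 -> 0,
    # t=31 -> 1, t=47 -> 2).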

    def _audit_sdpa_backends_once(
        self,
        Q: torch.Tensor,  # [B,1,G,h,Dk]
        K_sel: torch.Tensor,  # [B,G,S,Dk]
        V_sel: torch.Tensor,  # [B,G,S,Dv]
        K_win: torch.Tensor,  # [B,G,S,Dk]
        V_win: torch.Tensor,  # [B,G,S,Dv]
    ) -> None:
        if self._sdpa_audited:
            return
        try:
            from torch.nn.attention import sdpa_kernel
        except Exception:
            # Older torch, skip audit
            self._sdpa_audited = True
            return
        B = Q.shape[0]
        G = self.n_kv_groups
        h = self.h_per_group
        # Prepare a small representative slice per branch
        q = Q[:, 0]  # [B,G,h,Dk]
        # Ensure contiguity
        q = q.contiguous()
        ks = K_sel.contiguous()
        vs = V_sel.contiguous()
        kw = K_win.contiguous()
        vw = V_win.contiguous()

        def _probe(tag: str, k: torch.Tensor, v: torch.Tensor) -> str:
            try:
                with sdpa_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
                    q2 = q.reshape(B * G * h, 1, self.d_k)
                    k2 = (
                        k.unsqueeze(2)
                        .expand(B, G, h, k.shape[2], self.d_k)
                        .reshape(B * G * h, k.shape[2], self.d_k)
                    )
                    v2 = (
                        v.unsqueeze(2)
                        .expand(B, G, h, v.shape[2], self.d_v)
                        .reshape(B * G * h, v.shape[2], self.d_v)
                    )
                    _ = F.scaled_dot_product_attention(
                        q2.contiguous(), k2.contiguous(), v2.contiguous(), is_causal=True
                    )
                return "flash"
            except Exception:
                return "fallback"

        try:
            b_sel = _probe("cmp/win(sel)", ks, vs)
            b_win = _probe("win", kw, vw)
            log("sdpa.audit", sel=b_sel, win=b_win)
        except Exception:
            pass
        self._sdpa_audited = True

    def _forward_prefill_via_decode(
        self, x: torch.Tensor, kv: NSA_KV
    ) -> tuple[torch.Tensor, NSA_KV]:
        """Prefill by stepping decode one token at a time.

        This path avoids recursion back into prefill and guarantees progress.
        """
        B, S, _ = x.shape
        outs = []
        for t in range(S):
            out_t, kv = self.forward(x[:, t : t + 1], kv, prefill=False)
            outs.append(out_t)
        return torch.cat(outs, dim=1), kv

    def _forward_prefill_sequential(
        self, x: torch.Tensor, kv: NSA_KV
    ) -> tuple[torch.Tensor, NSA_KV]:
        """
        Reference prefill path (sequential per‑token), used for parity checks.
        """
        B, S, _ = x.shape
        # Projections
        Q_lin = self._shape_q(self.W_Q(x), B, S)  # [B,S,G,h,Dk]
        pos = torch.arange(S, device=x.device)
        Q = apply_rope(
            Q_lin.view(B, S, self.n_heads, self.d_k).reshape(B, S, self.n_heads * self.d_k),
            pos,
            scale=getattr(self, "rope_scale", 1.0),
        )
        Q = Q.view(B, S, self.n_heads, self.d_k).view(
            B, S, self.n_kv_groups, self.h_per_group, self.d_k
        )
        K_sel = self._shape_kv(self.W_K_sel(x), B, S)
        V_sel = self._shape_kv(self.W_V_sel(x), B, S)
        K_win = self._shape_kv(self.W_K_win(x), B, S)
        V_win = self._shape_kv(self.W_V_win(x), B, S)
        K_cmp_raw = self._shape_kv(self.W_K_cmp(x), B, S)
        V_cmp_raw = self._shape_kv(self.W_V_cmp(x), B, S)

        # Apply RoPE to per-branch K tensors to align with batched path
        pos_k = torch.arange(S, device=x.device)
        K_sel = apply_rope(K_sel, pos_k, scale=getattr(self, "rope_scale", 1.0))
        K_win = apply_rope(K_win, pos_k, scale=getattr(self, "rope_scale", 1.0))

        kv.update_selection_raw(K_sel, V_sel)
        kv.meta = build_block_meta(
            seq_len=S, l=self.l, d=self.d, l_sel=self.l_sel, n_sel=self.n_sel, w=self.w
        )
        kv.update_window(K_win, V_win, self.w)
        if self.phi_type == "mlp":
            K_cmp, V_cmp = self._phi_apply_seq(
                K_cmp_raw, V_cmp_raw, pos=torch.arange(S, device=x.device)
            )
        else:
            K_cmp, V_cmp = avg_pool_phi_rope_kv(
                K_cmp_raw, V_cmp_raw, self.l, self.d, pos=torch.arange(S, device=x.device)
            )
        kv.update_compressed(K_cmp, V_cmp, self.l, self.d)

        # Precompute p_grp_all batched for reuse per t
        scale = 1.0 / (self.d_k**0.5)
        p_cmp_all = compute_pcmp_all(Q, kv.K_cmp, scale)  # [B,S,G,h,S_cmp]
        p_slc_all = map_pcmp_to_pslc_batched(p_cmp_all, kv.meta)  # [B,S,G,h,S_sel]
        p_grp_all = p_slc_all.sum(dim=3)  # [B,S,G,S_sel]

        outs = []
        sel_ranges_accum: List[torch.Tensor] = []
        for t in range(S):
            p_grp = p_grp_all[:, t]  # [B,G,S_sel]
            sel_ranges = select_topn_ranges(p_grp, kv.meta, self.n_sel, t, True, 2)
            sel_ranges_accum.append(sel_ranges)
            Q_t = Q[:, t]
            K_sel_t = kv.K_sel[:, :, : t + 1, :]
            V_sel_t = kv.V_sel[:, :, : t + 1, :]
            # Selection attention routing (mirror decode/batched semantics)
            force_parity = self._env_cache.get("force_parity", False)
            use_sel_pack = self._env_cache.get("use_sel_pack", True) and not force_parity
            use_triton_sel = self._env_cache.get("use_triton_sel", False) and not force_parity
            use_cuda_sel = self._env_cache.get("use_cuda_sel", False) and not force_parity
            force_sel_mask = self._env_cache.get("force_sel_mask", False) and not force_parity
            if force_sel_mask:
                try:
                    O_sel_bt = grouped_selection_attention_masked(
                        Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                    )
                    O_sel = O_sel_bt[:, 0]
                    log("prefill.sel.path", path="masked_forced")
                except Exception as e:
                    self._fallback_counters["selection_mask_fails"] += 1
                    self._fallback_counters["total_fallbacks"] += 1
                    log(
                        "warn.masked_selection_prefill_forced_fallback",
                        error=str(e),
                        Q_shape=list(Q_t.shape),
                        K_shape=list(K_sel_t.shape),
                        V_shape=list(V_sel_t.shape),
                        ranges_shape=list(sel_ranges.shape),
                        total_fails=self._fallback_counters["selection_mask_fails"],
                    )
                    O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
            elif use_triton_sel:
                try:
                    from nsa.kernels.triton_sel_kernel import selection_attention_triton

                    O_sel_bt = selection_attention_triton(
                        Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                    )
                    O_sel = O_sel_bt[:, 0]
                    log("prefill.sel.path", path="triton")
                except Exception as e:
                    # Fallback counter - Triton selection failed (sequential prefill)
                    self._fallback_counters["selection_triton_fails"] += 1
                    self._fallback_counters["total_fallbacks"] += 1
                    log(
                        "warn.triton_selection_prefill_fallback",
                        error=str(e),
                        total_fails=self._fallback_counters["selection_triton_fails"],
                    )
                    # Fallback to packed SDPA
                    O_sel_bt = grouped_selection_attention_packed(
                        Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                    )
                    O_sel = O_sel_bt[:, 0]
            elif use_cuda_sel:
                try:
                    from nsa.kernels.cuda_sel_kernel import selection_attention_cuda

                    O_sel_bt = selection_attention_cuda(
                        Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                    )
                    O_sel = O_sel_bt[:, 0]
                except Exception as e:
                    # Fallback counter - CUDA selection failed (sequential prefill)
                    self._fallback_counters["selection_cuda_fails"] += 1
                    self._fallback_counters["total_fallbacks"] += 1
                    log(
                        "warn.cuda_selection_prefill_fallback",
                        error=str(e),
                        total_fails=self._fallback_counters["selection_cuda_fails"],
                    )
                    # Fallback to packed SDPA
                    O_sel_bt = grouped_selection_attention_packed(
                        Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                    )
                    O_sel = O_sel_bt[:, 0]
            elif use_sel_pack:
                try:
                    O_sel_bt = grouped_selection_attention_packed(
                        Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                    )
                    O_sel = O_sel_bt[:, 0]
                    log("prefill.sel.path", path="packed")
                except Exception as e:
                    # Fallback counter - Packed selection failed (sequential prefill)
                    self._fallback_counters["selection_pack_fails"] += 1
                    self._fallback_counters["total_fallbacks"] += 1
                    log(
                        "warn.packed_selection_prefill_fallback",
                        error=str(e),
                        total_fails=self._fallback_counters["selection_pack_fails"],
                    )
                    # Fallback to gather SDPA
                    O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
            elif self._env_cache.get("use_sel_mask", False) and not force_parity:
                try:
                    O_sel_bt = grouped_selection_attention_masked(
                        Q_t.unsqueeze(1), K_sel_t, V_sel_t, sel_ranges.unsqueeze(1)
                    )
                    O_sel = O_sel_bt[:, 0]
                    log("prefill.sel.path", path="masked")
                except Exception as e:
                    # Fallback counter - Masked selection failed (sequential prefill)
                    self._fallback_counters["selection_mask_fails"] += 1
                    self._fallback_counters["total_fallbacks"] += 1
                    log(
                        "warn.masked_selection_prefill_fallback",
                        error=str(e),
                        total_fails=self._fallback_counters["selection_mask_fails"],
                    )
                    # Fallback to gather SDPA
                    O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
            else:
                O_sel = self._sdpa_over_ranges(Q_t, K_sel_t, V_sel_t, sel_ranges)
            win_len = min(self.w, t + 1)
            K_w = kv.K_win[:, :, t + 1 - win_len : t + 1, :]
            V_w = kv.V_win[:, :, t + 1 - win_len : t + 1, :]
            O_win = attention_bgh(Q_t.contiguous(), K_w.contiguous(), V_w.contiguous(), causal=True)
            S_cmp_t = 0 if (t + 1) < self.l else (t + 1 - self.l) // self.d + 1
            O_cmp = attention_bgh(
                Q_t.contiguous(),
                kv.K_cmp[:, :, :S_cmp_t, :].contiguous(),
                kv.V_cmp[:, :, :S_cmp_t, :].contiguous(),
                causal=True,
            )
            q_gp = Q_t.mean(dim=2, dtype=Q_t.dtype)
            gates = self.gate(q_gp, tau=self.gate_temp)
            if self._env_cache.get("stopgrad_gates", False):
                gates = gates.detach()

            # Update gate statistics for M8 monitoring (accumulate across steps)
            self._update_gate_stats(gates)

            w_cmp = gates[..., 0:1].unsqueeze(-1)
            w_sel = gates[..., 1:2].unsqueeze(-1)
            w_win = gates[..., 2:3].unsqueeze(-1)
            O = w_cmp * O_cmp + w_sel * O_sel + w_win * O_win
            O_heads = O.reshape(B, self.n_heads, self.d_v)
            out_t = self.out(O_heads.reshape(B, 1, -1))
            outs.append(out_t)
        out = torch.cat(outs, dim=1)
        # Aggregate selection stats across all t in this prefill (sequential path)
        try:
            if sel_ranges_accum:
                # Stack to [T,B,G,n,2] then permute to [B,T,G,n,2]
                rs = torch.stack(sel_ranges_accum, dim=0).permute(1, 0, 2, 3, 4)
                self._update_sel_stats_from_ranges(rs)
        except Exception:
            pass
        return out, kv

    def _sdpa_full(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        # Q: [B,G,h,Dk]; K/V: [B,G,S,D*] -> out [B,G,h,Dv]
        B, G, h, Dk = Q.shape
        S = K.shape[2]
        q = Q.reshape(B * G * h, 1, Dk).contiguous()
        k = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B * G * h, S, Dk).contiguous()
        v = (
            V.unsqueeze(2)
            .expand(B, G, h, S, V.shape[-1])
            .reshape(B * G * h, S, V.shape[-1])
            .contiguous()
        )
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        o = attn.squeeze(1).reshape(B, G, h, -1)
        return o

    def _phi_apply_seq(
        self, K_raw: torch.Tensor, V_raw: torch.Tensor, pos: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply learnable ϕ over the full sequence using depthwise Conv1d initialized to avg.
        Expects K_raw, V_raw: [B,G,S,D*]; returns [B,G,S_cmp,D*].
        """
        assert self.phi_k_conv is not None and self.phi_v_conv is not None
        B, G, S, Dk = K_raw.shape
        Dv = V_raw.shape[-1]
        K_rope = apply_rope(K_raw, pos, scale=getattr(self, "rope_scale", 1.0))
        Kx = K_rope.permute(0, 1, 3, 2).reshape(B * G, Dk, S)
        Vx = V_raw.permute(0, 1, 3, 2).reshape(B * G, Dv, S)
        Kc = self.phi_k_conv(Kx)
        Vc = self.phi_v_conv(Vx)
        S_cmp = Kc.shape[-1]
        K_cmp = Kc.reshape(B, G, Dk, S_cmp).permute(0, 1, 3, 2).contiguous()
        V_cmp = Vc.reshape(B, G, Dv, S_cmp).permute(0, 1, 3, 2).contiguous()
        return K_cmp, V_cmp

    def _phi_apply_last(
        self, K_last: torch.Tensor, V_last: torch.Tensor, pos_last: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Emit a single compressed token from the last l raw tokens using Conv1d with kernel=l, stride=d.
        Inputs: [B,G,l,D*] -> Outputs: [B,G,1,D*].
        """
        assert self.phi_k_conv is not None and self.phi_v_conv is not None
        B, G, lwin, Dk = K_last.shape
        Dv = V_last.shape[-1]
        assert lwin == self.l, "decode emission expects exactly l tokens"
        K_rope = apply_rope(K_last, pos_last, scale=getattr(self, "rope_scale", 1.0))
        Kx = K_rope.permute(0, 1, 3, 2).reshape(B * G, Dk, lwin)
        Vx = V_last.permute(0, 1, 3, 2).reshape(B * G, Dv, lwin)
        Kc = self.phi_k_conv(Kx)
        Vc = self.phi_v_conv(Vx)
        K_cmp_new = Kc.reshape(B, G, Dk, 1).permute(0, 1, 3, 2).contiguous()
        V_cmp_new = Vc.reshape(B, G, Dv, 1).permute(0, 1, 3, 2).contiguous()
        return K_cmp_new, V_cmp_new
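
    # Note on ϕ (assumption made explicit): phi_k_conv / phi_v_conv are the
    # depthwise Conv1d layers named in the docstrings above, with kernel_size=l
    # and stride=d. Initialized to uniform 1/l weights they reproduce
    # avg_pool_phi_rope_kv exactly at init, so the "mlp" ϕ starts as plain
    # averaging and can learn a richer compression during training.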

    def _sdpa_over_ranges(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        ranges: torch.Tensor,
    ) -> torch.Tensor:
        """
        SDPA over concatenated gathered tokens per (B,G) according to `ranges`.

        Args:
            Q: [B,G,h,Dk]
            K: [B,G,S_kv,Dk]
            V: [B,G,S_kv,Dv]
            ranges: [B,G,n,2] start/end pairs
        Returns:
            [B,G,h,Dv]
        """
        # Concatenate gathered tokens per (B,G)
        B, G, h, Dk = Q.shape
        Dv = V.shape[-1]
        outs = []
        S_kv = K.shape[2]
        strict_asserts = (
            self._env_cache.get("strict_asserts", False) if hasattr(self, "_env_cache") else False
        )
        for b in range(B):
            row = []
            for g in range(G):
                # Clamp and validate ranges to avoid invalid or oversized indices
                r = ranges[b, g].to(dtype=torch.int64, device=K.device)  # [n,2]
                if r.numel() == 0:
                    valid_pairs = torch.empty((0, 2), dtype=torch.int64, device=K.device)
                else:
                    s = r[:, 0].clamp_(0, S_kv)
                    e = r[:, 1].clamp_(0, S_kv)
                    valid = e > s
                    valid_pairs = torch.stack([s[valid], e[valid]], dim=-1)

                # M8: Assert bounds for gathered ranges (GPU-sync gated)
                if strict_asserts and valid_pairs.numel() > 0:
                    max_end = valid_pairs[:, 1].max().item()
                    assert max_end <= S_kv, (
                        f"Selection range exceeds sequence length: max_end={max_end} > S_kv={S_kv} "
                        f"at batch={b}, group={g}."
                    )
                # Build a boolean mask over S_kv to gather selected tokens (limits worst-case size)
                if valid_pairs.numel() > 0:
                    m = torch.zeros((S_kv,), dtype=torch.bool, device=K.device)
                    for s_e in valid_pairs:
                        s_i = int(s_e[0].item())
                        e_i = int(s_e[1].item())
                        if e_i > s_i:
                            m[s_i:e_i] = True
                    idx = m.nonzero(as_tuple=False).squeeze(-1)
                else:
                    idx = torch.empty((0,), dtype=torch.int64, device=K.device)
                k = (
                    K[b, g, idx]
                    if idx.numel() > 0
                    else torch.zeros((1, Dk), device=K.device, dtype=K.dtype)
                )
                v = (
                    V[b, g, idx]
                    if idx.numel() > 0
                    else torch.zeros((1, Dv), device=K.device, dtype=V.dtype)
                )
                q = Q[b, g]  # [h,Dk]
                attn = F.scaled_dot_product_attention(
                    q.unsqueeze(0).contiguous(),
                    k.unsqueeze(0).contiguous(),
                    v.unsqueeze(0).contiguous(),
                    is_causal=True,
                )
                row.append(attn.squeeze(0))  # [h,Dv]
            outs.append(torch.stack(row, dim=0))  # [G,h,Dv]
        return torch.stack(outs, dim=0)  # [B,G,h,Dv]
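
# Gather-SDPA reference semantics for _sdpa_over_ranges (summary): each
# [start, end) pair in `ranges` is clamped to [0, S_kv], merged into one boolean
# mask over the KV sequence per (b, g), and the surviving tokens are gathered
# for a single causal SDPA call. The per-(b, g) Python loop keeps this path
# simple; it is the slow reference that the packed/masked/Triton selection
# kernels above fall back to.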
nsa/core/packing.py
ADDED
@@ -0,0 +1,114 @@
from __future__ import annotations

from typing import List

import torch


def compute_sliding_lengths(S: int, w: int, device: torch.device) -> torch.Tensor:
    """
    Return per-row window lengths for sliding attention: L_t = min(w, t+1)
    Shape: [S]
    """
    tpos = torch.arange(S, device=device)
    return (tpos + 1).clamp_max(w)


def compute_compressed_lengths(
    S: int, l: int, d: int, S_cmp: int, device: torch.device
) -> torch.Tensor:
    """
    Return per-row valid compressed lengths: num_cmp(t)
    Shape: [S]
    """
    tpos = torch.arange(S, device=device)
    return torch.where(tpos + 1 < l, 0, ((tpos + 1 - l) // d) + 1).clamp(min=0, max=S_cmp)


def build_length_buckets(lengths: torch.Tensor) -> List[torch.Tensor]:
    """
    Group row indices by identical length.
    Args:
        lengths: [S] int tensor
    Returns:
        List of index tensors, one per unique length (descending by length)
    """
    if lengths.numel() == 0:
        return []
    unique = torch.unique(lengths, sorted=True)
    # sort descending so larger buckets are processed first
    unique = torch.flip(unique, dims=[0])
    buckets: List[torch.Tensor] = []
    for L in unique.tolist():
        idx = torch.nonzero(lengths == int(L), as_tuple=False).flatten()
        buckets.append(idx)
    return buckets


def build_cu_seqlens_for_buckets(bucket_lengths: torch.Tensor) -> torch.Tensor:
    """
    Build cumulative sequence lengths (cu_seqlens) for varlen APIs from a vector of lengths.
    Args:
        bucket_lengths: [N] lengths per row in a bucket
    Returns:
        cu_seqlens: [N+1] with cu_seqlens[0]=0 and cu_seqlens[i+1]=sum_{j<=i} len[j]
    """
    if bucket_lengths.numel() == 0:
        return torch.zeros((1,), dtype=torch.int32, device=bucket_lengths.device)
    cs = torch.zeros((bucket_lengths.numel() + 1,), dtype=torch.int32, device=bucket_lengths.device)
    cs[1:] = torch.cumsum(bucket_lengths.to(dtype=torch.int32), dim=0)
    return cs


def pack_batch_by_lengths(
    x: torch.Tensor, lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Pack a batch of padded rows into a contiguous buffer with cu_seqlens.

    Args:
        x: [B,S_max,D]
        lengths: [B] valid lengths per row
    Returns:
        (packed: [sum(lengths), D], cu_seqlens: [B+1])
    """
    device = x.device
    B, S_max, D = x.shape
    assert lengths.shape[0] == B
    cu = build_cu_seqlens_for_buckets(lengths.to(torch.int32))
    N = int(cu[-1].item())
    packed = torch.empty((N, D), dtype=x.dtype, device=device)
    write = 0
    for b in range(B):
        L = int(lengths[b].item())
        if L > 0:
            packed[write : write + L] = x[b, :L]
            write += L
    return packed, cu


def unpack_packed_to_padded(
    packed: torch.Tensor, cu_seqlens: torch.Tensor, S_max: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Unpack a packed buffer back to a padded batch and mask.

    Args:
        packed: [N,D]
        cu_seqlens: [B+1]
        S_max: target padded length
    Returns:
        (padded [B,S_max,D], mask [B,S_max])
    """
    device = packed.device
    B = cu_seqlens.shape[0] - 1
    D = packed.shape[-1]
    padded = torch.zeros((B, S_max, D), dtype=packed.dtype, device=device)
    mask = torch.zeros((B, S_max), dtype=torch.bool, device=device)
    for b in range(B):
        start = int(cu_seqlens[b].item())
        end = int(cu_seqlens[b + 1].item())
        L = end - start
        if L > 0:
            padded[b, :L] = packed[start:end]
            mask[b, :L] = True
    return padded, mask
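A quick round-trip through the two packing helpers above, assuming the repo root is on PYTHONPATH (toy shapes):

```python
import torch
from nsa.core.packing import pack_batch_by_lengths, unpack_packed_to_padded

# Toy batch: 3 padded rows of max length 5, with valid lengths 5, 2, 0.
x = torch.randn(3, 5, 8)
lengths = torch.tensor([5, 2, 0])

packed, cu = pack_batch_by_lengths(x, lengths)    # packed: [7, 8], cu: [0, 5, 7, 7]
padded, mask = unpack_packed_to_padded(packed, cu, S_max=5)

assert torch.equal(padded[0], x[0])               # full row survives the round trip
assert torch.equal(padded[1, :2], x[1, :2])       # only the valid prefix is restored
assert mask.sum() == lengths.sum()
```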
nsa/core/rope.py
ADDED
|
@@ -0,0 +1,51 @@
from __future__ import annotations

import torch


def build_inv_freq(
    dim: int, base: float = 10000.0, device: torch.device | None = None
) -> torch.Tensor:
    assert dim % 2 == 0, "RoPE requires even dimension"
    half = dim // 2
    idx = torch.arange(half, device=device, dtype=torch.float32)
    inv_freq = base ** (-2 * idx / dim)
    return inv_freq  # [half]


def apply_rope(
    x: torch.Tensor,
    pos: torch.Tensor,
    base: float = 10000.0,
    *,
    scale: float = 1.0,
) -> torch.Tensor:
    """
    Apply rotary position embeddings along the last dimension.

    x: [..., S, D] tensor with even D
    pos: [S] or [..., S] integer positions
    returns: same shape as x
    """
    D = x.shape[-1]
    assert D % 2 == 0, "RoPE requires even dimension"
    device = x.device
    inv_freq = build_inv_freq(D, base=base, device=device)  # [D/2]
    # pos shape broadcasting to [..., S, D/2]
    while pos.dim() < x.dim() - 1:
        pos = pos.unsqueeze(0)
    # Simple NTK/YARN-style extension via position scaling: effective_pos = pos / scale
    if scale <= 0:
        scale = 1.0
    # Compute angles in float32 for accuracy, then cast sin/cos to the input dtype to preserve dtype end-to-end
    angles = (pos.to(torch.float32) / float(scale)).unsqueeze(-1) * inv_freq  # [..., S, D/2] (float32)
    sin = torch.sin(angles).to(dtype=x.dtype)
    cos = torch.cos(angles).to(dtype=x.dtype)
    x_2 = x.view(*x.shape[:-1], D // 2, 2)
    x0, x1 = x_2[..., 0], x_2[..., 1]
    y0 = x0 * cos - x1 * sin
    y1 = x0 * sin + x1 * cos
    y = torch.stack((y0, y1), dim=-1).view_as(x)
    return y
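Typical call into apply_rope (illustrative only; shapes are arbitrary). Since RoPE is a per-pair rotation, vector norms are preserved up to float rounding:

```python
import torch
from nsa.core.rope import apply_rope

q = torch.randn(2, 8, 64)        # [B, S, D], D must be even
pos = torch.arange(8)            # integer positions

q_rot = apply_rope(q, pos)
assert q_rot.shape == q.shape and q_rot.dtype == q.dtype
# Rotation preserves per-token norms up to fp32 rounding.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), rtol=1e-4, atol=1e-4)
```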
nsa/core/selection_scorer.py
ADDED
|
@@ -0,0 +1,759 @@
from __future__ import annotations

import os
from typing import List, Tuple

import torch
import torch.nn.functional as F

from .block_index import BlockMeta


def compute_pcmp(Q: torch.Tensor, K_cmp: torch.Tensor, scale: float) -> torch.Tensor:
    # Q: [G,h,Dk]; K_cmp: [B,G,S_cmp,Dk] with implicit B=1 for this path
    if Q.dim() == 3:
        # Q: [G,h,Dk]; K_cmp: [1,G,S_cmp,Dk] (implicit B=1)
        G, h, Dk = Q.shape
        S_cmp = K_cmp.shape[2]
        q = Q.reshape(G * h, 1, Dk)
        # Expand K over heads without materializing copies
        k = (
            K_cmp[0]
            .unsqueeze(1)  # [G,1,S_cmp,Dk]
            .expand(G, h, S_cmp, Dk)
            .reshape(G * h, S_cmp, Dk)
        )
        logits = torch.bmm(q, k.transpose(1, 2)).squeeze(1) * scale
        return F.softmax(logits, dim=-1).reshape(1, G, h, S_cmp)
    else:
        # Q: [B,G,h,Dk]; K_cmp: [B,G,S_cmp,Dk]
        B, G, h, Dk = Q.shape
        S_cmp = K_cmp.shape[2]
        q = Q.reshape(B * G * h, 1, Dk)
        # Expand K over heads without materializing copies
        k = (
            K_cmp.unsqueeze(2)  # [B,G,1,S_cmp,Dk]
            .expand(B, G, h, S_cmp, Dk)
            .reshape(B * G * h, S_cmp, Dk)
        )
        logits = torch.bmm(q, k.transpose(1, 2)).squeeze(1) * scale
        p = F.softmax(logits, dim=-1)
        return p.reshape(B, G, h, S_cmp)


def compute_pcmp_all(Q_all: torch.Tensor, K_cmp: torch.Tensor, scale: float) -> torch.Tensor:
    """
    Q_all: [B,S,G,h,Dk], K_cmp: [B,G,S_cmp,Dk] -> p_cmp_all: [B,S,G,h,S_cmp]
    """
    use_mixed = os.getenv("NSA_P_CMP_MIXED", "0").lower() in ("1", "true", "yes", "on")
    if use_mixed and Q_all.device.type == "cuda":
        # Optional mixed-precision path (disabled by default). Computes logits and softmax
        # under autocast to reduce memory bandwidth on large shapes. Output is upcast
        # back to the original dtype to preserve downstream numerics.
        orig_dtype = Q_all.dtype
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            Kt = K_cmp.permute(0, 1, 3, 2)  # [B,G,Dk,S_cmp]
            logits = torch.einsum("bsghd,bgdc->bsghc", Q_all, Kt) * scale
            p = F.softmax(logits, dim=-1)
        return p.to(orig_dtype)
    else:
        # Baseline precise path
        Kt = K_cmp.permute(0, 1, 3, 2)  # [B,G,Dk,S_cmp]
        logits = torch.einsum("bsghd,bgdc->bsghc", Q_all, Kt) * scale
        return F.softmax(logits, dim=-1)


def map_pcmp_to_pslc(p_cmp: torch.Tensor, meta: BlockMeta) -> torch.Tensor:
    # p_cmp: [B,G,h,S_cmp]
    B, G, h, S_cmp = p_cmp.shape
    indptr = meta.M_csl_indptr
    indices = meta.M_csl_indices
    values = meta.M_csl_values
    S_sel = meta.sel_starts.numel()
    device = p_cmp.device
    # Out-of-place accumulation to avoid in-place versioning issues under GC/DDP
    acc = torch.zeros((B, G, h, S_sel), device=device, dtype=p_cmp.dtype)
    # CSR row-wise multiply-add
    for r in range(S_cmp):
        start, end = int(indptr[r].item()), int(indptr[r + 1].item())
        if start == end:
            continue
        cols = indices[start:end].to(device)
        w = values[start:end].to(device=device, dtype=p_cmp.dtype)  # [nnz_r]
        contrib = p_cmp[..., r].unsqueeze(-1) * w  # [B,G,h,nnz_r]
        # Ensure Long dtype for scatter_add indices
        idx = cols.view(1, 1, 1, -1).expand(B, G, h, -1).long()
        acc = acc.scatter_add(-1, idx, contrib)
    return acc


def map_pcmp_to_pslc_batched(p_cmp_all: torch.Tensor, meta: BlockMeta) -> torch.Tensor:
    """
    p_cmp_all: [B,S,G,h,S_cmp] -> p_slc_all: [B,S,G,h,S_sel]
    Vectorized over B,S,G,h.
    """
    B, S, G, h, S_cmp = p_cmp_all.shape
    device = p_cmp_all.device
    S_sel = meta.sel_starts.numel()
    if S_cmp == 0:
        return torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)
    # COO sparse matmul: for each nnz (r,c,w), add p_cmp[..., r]*w to p_slc[..., c]
    rows, cols = meta.M_csl_coo_indices.to(device)
    w = meta.M_csl_coo_values.to(device=device, dtype=p_cmp_all.dtype)
    # Filter mapping rows to those < current S_cmp to avoid out-of-bounds in early decode
    valid_mask = rows < S_cmp
    if valid_mask.dim() == 0:
        valid_mask = valid_mask.unsqueeze(0)
    rows = rows[valid_mask]
    cols = cols[valid_mask]
    w = w[valid_mask]
    if rows.numel() == 0:
        return torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)
    p_src = p_cmp_all[..., rows] * w  # [B,S,G,h,nnz]
    p_slc = torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)
    # Ensure Long dtype for scatter_add indices
    idx = cols.view(1, 1, 1, 1, -1).expand(B, S, G, h, -1).long()
    p_slc = p_slc.scatter_add(-1, idx, p_src)
    return p_slc


def group_reduce_pslc(p_slc: torch.Tensor) -> torch.Tensor:
    # Sum across heads in group (Eq. 10)
    return p_slc.sum(dim=2)

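The COO mapping above is easiest to see on a tiny hand-built example. The SimpleNamespace below is a stand-in for BlockMeta (the real class lives in nsa/core/block_index.py) carrying only the attributes this path reads; all values are made up:

```python
import torch
from types import SimpleNamespace
from nsa.core.selection_scorer import map_pcmp_to_pslc_batched

# Hypothetical mapping: cmp block 0 -> sel block 0 (w=1.0);
# cmp block 1 overlaps sel blocks 0 and 1 (w=0.5 each).
meta = SimpleNamespace(
    sel_starts=torch.tensor([0, 64]),
    M_csl_coo_indices=torch.tensor([[0, 1, 1], [0, 0, 1]]),  # [rows; cols]
    M_csl_coo_values=torch.tensor([1.0, 0.5, 0.5]),
)

p_cmp = torch.tensor([0.25, 0.75]).view(1, 1, 1, 1, 2)      # [B,S,G,h,S_cmp]
p_slc = map_pcmp_to_pslc_batched(p_cmp, meta)                # [B,S,G,h,S_sel]
# sel block 0 gets 0.25*1.0 + 0.75*0.5 = 0.625; sel block 1 gets 0.75*0.5 = 0.375
assert torch.allclose(p_slc.flatten(), torch.tensor([0.625, 0.375]))
```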
def select_topn_ranges(
    p_grp: torch.Tensor,
    meta: BlockMeta,
    n_top: int,
    t_token: int,
    force_init: bool = True,
    force_local: int = 2,
    _skip_validation: bool = False,
) -> torch.Tensor:
    """Select top-n block ranges with deterministic tie-breaking.

    M8: Enhanced with robust deterministic tie-breaking for training reproducibility.
    Uses a scaled epsilon bias to prefer lower indices on score ties, ensuring
    identical selection across runs with the same inputs.

    Args:
        p_grp: Group probabilities [B,G,S_sel]
        meta: Block metadata with selection ranges
        n_top: Number of top blocks to select
        t_token: Current token position (0-indexed)
        force_init: Whether to force-include block 0
        force_local: Number of local blocks to force-include

    Returns:
        Selected ranges [B,G,n_top,2] as [start,end) pairs
    """
    # p_grp: [B,G,S_sel]
    B, G, S_sel = p_grp.shape
    device = p_grp.device
    # Determine candidate blocks ≤ t
    sel_starts = meta.sel_starts.to(device)
    # mask future blocks
    valid = sel_starts + meta.l_sel - 1 <= t_token
    masked = p_grp.masked_fill(~valid.view(1, 1, -1), float("-inf"))
    # force-includes set
    forced_list = []
    if force_init:
        forced_list.append(torch.zeros((B, G), dtype=torch.int64, device=device))
    if force_local > 0:
        last_block = torch.clamp((torch.tensor(t_token, device=device) // meta.l_sel), min=0)
        for i in range(force_local):
            forced_list.append(torch.clamp(last_block - i, min=0).expand(B, G))
    forced_idx = (
        torch.stack(forced_list, dim=-1)
        if forced_list
        else torch.empty((B, G, 0), device=device, dtype=torch.int64)
    )
    # Exclude forced blocks from top-k candidates by setting their scores to -inf
    if forced_idx.numel() > 0:
        forced_mask = torch.zeros_like(masked, dtype=torch.bool)
        forced_mask.scatter_(-1, forced_idx, True)
        masked = masked.masked_fill(forced_mask, float("-inf"))
    # pick the remaining blocks to fill up to n_top
    k_rest = torch.clamp(torch.tensor(n_top - forced_idx.shape[-1], device=device), min=0).item()
    if k_rest > 0:
        # M8: Deterministic tie-breaker - prefer lower indices for reproducible selection.
        # Use a tiny, fixed bias in float32 space to avoid overwhelming scores in low-precision
        # dtypes (e.g., bf16/fp16). Ranking is performed in float32 regardless of input dtype.
        tie_break_scale = torch.tensor(1e-8, device=device, dtype=torch.float32)
        base_idx = torch.arange(S_sel, device=device, dtype=torch.float32).view(1, 1, S_sel)
        composite = masked.to(torch.float32) - (base_idx * tie_break_scale)
        # Ensure deterministic topk with sorted=True for consistent ordering
        k_actual = min(k_rest, S_sel)
        _, top_idx = torch.topk(composite, k=k_actual, dim=-1, largest=True, sorted=True)

        # M8: Assert tie-breaking worked - check for potential numerical issues
        if torch.is_grad_enabled():
            # Only check during training when gradients are enabled
            with torch.no_grad():
                orig_scores = torch.gather(masked, -1, top_idx).to(torch.float32)
                if orig_scores.numel() > 1:
                    # Check whether adjacent scores are suspiciously close (potential tie-break failure)
                    score_diffs = torch.diff(orig_scores, dim=-1)
                    very_close = torch.abs(score_diffs) < (float(tie_break_scale.item()) * 0.1)
                    if very_close.any():
                        from nsa.core.debug import log

                        log(
                            "warn.selection_tiebreak",
                            msg="Close scores detected in selection - potential tie-break instability",
                            min_diff=float(torch.abs(score_diffs).min().item()),
                            tie_break_scale=float(tie_break_scale),
                        )
        sel_idx = torch.cat([forced_idx, top_idx], dim=-1)
    else:
        sel_idx = forced_idx
    # sort selected indices ascending for consistent range merging
    sel_idx = torch.sort(sel_idx, dim=-1).values

    # M8: Optional determinism validation (skipped when called from the validation itself)
    if not _skip_validation and os.getenv("NSA_VALIDATE_SELECTION_DETERMINISM", "0").lower() in (
        "1",
        "true",
        "yes",
    ):
        validate_selection_determinism(p_grp, meta, n_top, t_token)
    # merge adjacent blocks into contiguous ranges
    ranges = []
    for b in range(B):
        bg = []
        for g in range(G):
            blocks = sel_starts[sel_idx[b, g]]  # [k], sorted non-decreasing
            # Deduplicate without an extra sort (faster on GPU for small k)
            blocks = torch.unique_consecutive(blocks)
            if blocks.numel() == 0:
                bg.append(torch.zeros((n_top, 2), dtype=torch.int32, device=device))
                continue
            cur_s = int(blocks[0].item())
            cur_e = cur_s + meta.l_sel
            merged: List[Tuple[int, int]] = []
            for x in blocks[1:].tolist():
                if x == cur_e:  # adjacent
                    cur_e += meta.l_sel
                else:
                    merged.append((cur_s, cur_e))
                    cur_s, cur_e = x, x + meta.l_sel
            merged.append((cur_s, cur_e))
            # pad/truncate to n_top
            out = torch.zeros((n_top, 2), dtype=torch.int32, device=device)
            for i, (s, e) in enumerate(merged[:n_top]):
                e = min(e, t_token + 1)
                out[i, 0] = s
                out[i, 1] = e
            bg.append(out)
        ranges.append(torch.stack(bg, dim=0))
    return torch.stack(ranges, dim=0)  # [B,G,n_top,2]

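The composite-key tie-break used in both selection paths can be checked in isolation. A minimal sketch (using a bias well above float32 ULP at these magnitudes; the module itself uses 1e-8):

```python
import torch

scores = torch.tensor([0.5, 0.9, 0.9, 0.1])            # blocks 1 and 2 tie
eps = 1e-6
bias = torch.arange(scores.numel(), dtype=torch.float32) * eps
_, top = torch.topk(scores - bias, k=2, sorted=True)    # rank on the biased key
assert top.tolist() == [1, 2]                           # lower index wins the tie
```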
# ===== Batched selection (prefill fast path) =====


def select_topn_ranges_batched(
    p_grp_all: torch.Tensor,  # [B,S,G,S_sel]
    meta: BlockMeta,
    n_top: int,
    S: int,
    force_init: bool = True,
    force_local: int = 2,
) -> torch.Tensor:  # [B,S,G,n_ranges,2]
    """
    M8: Deterministic batched selection with enhanced tie-breaking:
    - Mask future blocks per position t via block end ≤ t+1
    - Force-include block 0 and the last k local blocks (deduplicated)
    - Exclude forced blocks from the scored top-k
    - Robust deterministic tie-break to the lower index on equal scores
    - Convert to merged contiguous [start,end) ranges clamped to ≤ t+1
    - Validation hooks for training reproducibility
    """
    B, S_q, G, S_sel = p_grp_all.shape
    device = p_grp_all.device

    sel_starts = meta.sel_starts.to(device)
    sel_ends = sel_starts + meta.l_sel
    tpos = torch.arange(S, device=device).view(S, 1)
    valid = sel_ends.view(1, -1) <= (tpos + 1)  # [S,S_sel]
    disallowed = ~valid
    masked = p_grp_all.masked_fill(disallowed.view(1, S, 1, S_sel), float("-inf"))

    # Forced blocks (deduplicated across block 0 and locals)
    forced_list = []
    if force_init:
        forced_list.append(torch.zeros((B, S, G, 1), dtype=torch.long, device=device))
    if force_local > 0:
        tpos1 = torch.arange(S, device=device)
        last_block = (tpos1 // meta.l_sel).clamp_min(0)
        for k in range(force_local):
            idx = (last_block - k).clamp_min(0).view(1, S, 1, 1).expand(B, S, G, 1)
            forced_list.append(idx)
    forced = (
        torch.cat(forced_list, dim=-1)
        if forced_list
        else torch.empty((B, S, G, 0), dtype=torch.long, device=device)
    )
    if forced.numel() > 0:
        # Ensure ascending order along the trailing dim, then drop consecutive duplicates
        forced = torch.sort(forced, dim=-1).values
        forced = torch.unique_consecutive(forced, dim=-1)

    if forced.numel() > 0:
        forced_mask = torch.zeros_like(masked, dtype=torch.bool)
        forced_mask.scatter_(-1, forced, True)
        masked = masked.masked_fill(forced_mask, float("-inf"))

    # Deterministic top-k using a composite key with a tiny index bias
    k_rest = max(0, n_top - forced.shape[-1])
    if k_rest > 0:
        # M8: Deterministic tie-breaker - prefer lower indices; rank in float32 to avoid
        # overwhelming biases under low-precision dtypes.
        tie_break_scale = torch.tensor(1e-8, device=device, dtype=torch.float32)
        base_idx = (
            torch.arange(S_sel, device=device, dtype=torch.float32)
            .view(1, 1, 1, S_sel)
            .expand(B, S, G, S_sel)
        )
        composite = masked.to(torch.float32) - (base_idx * tie_break_scale)
        # Ensure deterministic topk with explicit sorted=True for the batched path
        k_actual = min(k_rest, S_sel)
        _, top_idx = torch.topk(composite, k=k_actual, dim=-1, largest=True, sorted=True)

        # M8: Optional validation for tie-breaking effectiveness in training
        if torch.is_grad_enabled() and k_actual > 1:
            with torch.no_grad():
                orig_scores = torch.gather(masked, -1, top_idx).to(torch.float32)
                # Check the last dimension for potential tie-break issues
                score_diffs = torch.diff(orig_scores, dim=-1)
                very_close = torch.abs(score_diffs) < (float(tie_break_scale.item()) * 0.1)
                if very_close.any():
                    from nsa.core.debug import log

                    log(
                        "warn.batched_selection_tiebreak",
                        msg="Close scores in batched selection - potential instability",
                        batch_close_count=int(very_close.sum().item()),
                        tie_break_scale=float(tie_break_scale),
                    )
        selected = torch.cat([forced, top_idx], dim=-1)
    else:
        selected = forced[..., :n_top]

    # Keep only valid (≤ t) indices; drop disallowed fill-ins
    valid_full = valid.view(1, S, 1, S_sel).expand(B, S, G, S_sel)
    is_valid_pick = torch.gather(valid_full, -1, selected)
    # Replace invalid picks with a -1 sentinel
    selected = torch.where(is_valid_pick, selected, torch.full_like(selected, -1))
    # Special case: if the requested n_top ≥ number of valid blocks at t, select exactly all valid blocks [0..t]
    num_valid = valid.sum(dim=1)  # [S]
    # Build ascending [0..S_sel-1] to pick a prefix per t
    all_idx = torch.arange(S_sel, device=device).view(1, 1, 1, S_sel).expand(B, S, G, S_sel)
    pick_mask = all_idx < num_valid.view(1, S, 1, 1)
    if n_top >= S_sel:
        selected = torch.where(pick_mask, all_idx, torch.full_like(all_idx, -1))
    selected = torch.sort(selected, dim=-1).values
    # Env-gated GPU range conversion (v2) to remove Python loops from the hot path
    use_v2 = os.getenv("NSA_SEL_RANGES_V2", "1").lower() in ("1", "true", "yes")
    if use_v2:
        ranges = convert_indices_to_ranges_batched_v2(selected, meta, S)
    else:
        ranges = convert_indices_to_ranges_batched(selected, meta, S)
    return ranges


def convert_indices_to_ranges_batched_dispatch(
    indices: torch.Tensor,
    meta: BlockMeta,
    S: int,
) -> torch.Tensor:
    """
    Dispatch helper mirroring production behavior: chooses v2 by default unless disabled.
    Exposed for tests and tooling.
    """
    use_v2 = os.getenv("NSA_SEL_RANGES_V2", "1").lower() in ("1", "true", "yes")
    if use_v2:
        return convert_indices_to_ranges_batched_v2(indices, meta, S)
    return convert_indices_to_ranges_batched(indices, meta, S)

def convert_indices_to_ranges_batched(
    indices: torch.Tensor,  # [B,S,G,k]
    meta: BlockMeta,
    S: int,
) -> torch.Tensor:  # [B,S,G,n_max,2]
    B, S_q, G, k = indices.shape
    device = indices.device
    sel_starts = meta.sel_starts.to(device)

    all_ranges = []
    for b in range(B):
        for t in range(S_q):
            clamp_end = int(t) + 1
            for g in range(G):
                block_ids = [int(x) for x in indices[b, t, g].tolist() if int(x) >= 0]
                spans = []
                last_s, last_e = None, None
                prev = None
                for bid in block_ids:
                    # Skip invalid/out-of-range indices defensively
                    if bid < 0 or bid >= sel_starts.numel():
                        continue
                    if prev is not None and bid == prev:
                        continue
                    prev = bid
                    s0 = int(sel_starts[bid].item())
                    e0 = min(s0 + meta.l_sel, clamp_end)
                    if e0 <= s0:
                        continue
                    if last_s is None:
                        last_s, last_e = s0, e0
                    elif s0 == last_e:
                        last_e = e0
                    else:
                        spans.append((last_s, last_e))
                        last_s, last_e = s0, e0
                if last_s is not None:
                    spans.append((last_s, last_e))
                all_ranges.append(spans)

    max_ranges = max((len(r) for r in all_ranges), default=0)
    out = torch.zeros((B, S_q, G, max_ranges, 2), dtype=torch.int32, device=device)
    idx = 0
    for b in range(B):
        for t in range(S_q):
            for g in range(G):
                spans = all_ranges[idx]
                for i, (s0, e0) in enumerate(spans):
                    out[b, t, g, i, 0] = s0
                    out[b, t, g, i, 1] = e0
                idx += 1
    return out


def convert_indices_to_ranges_batched_v2(
    indices: torch.Tensor,  # [B,S,G,k], sorted ascending, -1 padded
    meta: BlockMeta,
    S: int,
) -> torch.Tensor:  # [B,S,G,k,2] (padded with zero-length ranges)
    """
    Vectorized GPU range conversion with no Python loops.
    - Treat equal and +1 successive block ids as a single merged run.
    - Map runs to token [start, end) using sel_starts and l_sel.
    - Clamp end to t+1 per row to preserve causality.
    - Output is padded to k runs per row; zero-length ranges are encoded as [0,0].
    """
    # NVTX annotation support
    _nvtx = os.getenv("NSA_NVTX", "0").lower() in ("1", "true", "yes")
    if _nvtx:
        try:
            torch.cuda.nvtx.range_push("nsa.sel.ranges_v2")
        except Exception:
            _nvtx = False

    device = indices.device
    B, S_q, G, K = indices.shape
    if K == 0:
        return torch.zeros((B, S_q, G, 0, 2), dtype=torch.int32, device=device)

    # Valid mask and prepared index tensor
    if _nvtx:
        try:
            torch.cuda.nvtx.range_push("v2_run_detection")
        except Exception:
            pass

    valid = indices.ge(0)
    x = torch.where(valid, indices, torch.full_like(indices, -2))  # sentinel -2

    # Identify run starts: first valid element or a break in adjacency (including dedup collapse)
    x_shift = torch.cat([torch.full_like(x[..., :1], -2), x[..., :-1]], dim=-1)
    prev_valid = x_shift.ge(0)
    diff = x - x_shift
    adjacent_or_dup = (diff.eq(1) | diff.eq(0)) & prev_valid
    run_start = valid & (~adjacent_or_dup | (~prev_valid))

    if _nvtx:
        try:
            torch.cuda.nvtx.range_pop()
        except Exception:
            pass

    # Row-local run ids [0..runs_per_row-1], -1 for invalid
    run_id = run_start.to(torch.int32).cumsum(dim=-1) - 1
    run_id = torch.where(valid, run_id, torch.full_like(run_id, -1))

    # Number of runs per row and flattened row indexing
    runs_per_row = run_start.sum(dim=-1, dtype=torch.int32)  # [B,S,G]
    N = B * S_q * G
    runs_per_row_flat = runs_per_row.reshape(N)

    # Build flattened per-run metadata
    # Flatten the last dim for selection
    run_start_flat = run_start.reshape(-1, K)
    x_flat = x.reshape(-1, K)
    run_id_flat = run_id.reshape(-1, K)

    # Indices (within the last dim) where runs start per row
    pos = torch.arange(K, device=device, dtype=torch.int32)
    pos_flat = pos.view(1, K).expand(run_start_flat.shape[0], K)
    start_pos_flat = pos_flat[run_start_flat]
    # Corresponding block ids where runs start
    start_blk_flat = x_flat[run_start_flat].to(torch.int32)

    # Build unique global run ids by offsetting row-local run ids with row offsets
    run_offsets = torch.cumsum(torch.nn.functional.pad(runs_per_row_flat, (1, 0)), dim=0)[:-1]  # [N]
    # Row index per element (0..N-1)
    row_ids = torch.arange(N, device=device, dtype=torch.int32)
    row_ids_per_elem = row_ids.view(N, 1).expand(N, K)
    # Global run id per element; -1 for invalid
    global_rid = torch.where(
        run_id_flat.ge(0),
        run_id_flat + run_offsets.view(N, 1),
        torch.full_like(run_id_flat, -1),
    )
    global_rid_valid = global_rid[run_id_flat.ge(0)]  # [total_valid_elems]

    # For each global run, compute the max block id in that run (end block)
    if _nvtx:
        try:
            torch.cuda.nvtx.range_push("v2_scatter_reduce")
        except Exception:
            pass

    total_runs = int(runs_per_row_flat.sum().item())
    if total_runs == 0:
        if _nvtx:
            try:
                torch.cuda.nvtx.range_pop()
            except Exception:
                pass
        return torch.zeros((B, S_q, G, K, 2), dtype=torch.int32, device=device)
    max_blk = torch.full((total_runs,), -2, dtype=torch.int32, device=device)
    # Values to reduce are block ids for valid elements
    blk_vals = x_flat[run_id_flat.ge(0)].to(torch.int32)
    max_blk.scatter_reduce_(
        0, global_rid_valid.to(torch.int64), blk_vals, reduce="amax", include_self=False
    )

    if _nvtx:
        try:
            torch.cuda.nvtx.range_pop()
        except Exception:
            pass

    # Start block ids per run, collected in row order
    start_blk_per_run = start_blk_flat  # length == total_runs

    # Map block ids to token starts/ends (guard invalid/out-of-range)
    sel_starts = meta.sel_starts.to(device=device, dtype=torch.int32)
    S_sel = int(sel_starts.numel())
    l_sel = int(meta.l_sel)
    valid_runs = (
        (start_blk_per_run >= 0)
        & (start_blk_per_run < S_sel)
        & (max_blk >= 0)
        & (max_blk < S_sel)
    )
    # Default zeros; fill only valid runs
    start_tok_flat = torch.zeros_like(start_blk_per_run, dtype=torch.int32, device=device)
    end_tok_flat = torch.zeros_like(max_blk, dtype=torch.int32, device=device)
    if valid_runs.any():
        start_tok_flat[valid_runs] = sel_starts[start_blk_per_run[valid_runs]]
        end_tok_flat[valid_runs] = sel_starts[max_blk[valid_runs]] + l_sel

    # Clamp end to t+1 per row (only meaningful for valid runs)
    # Row t positions: [S] repeated over B,G
    tpos = torch.arange(S, device=device, dtype=torch.int32)
    t_rows = tpos.view(1, S, 1).expand(B, S, G).reshape(N)  # [N]
    # t per run: repeat per row by runs_per_row
    t_per_run = torch.repeat_interleave(t_rows, runs_per_row_flat)
    end_tok_flat = torch.minimum(end_tok_flat, (t_per_run + 1))

    # Prepare output [B,S,G,K,2]: fill zeros, then scatter the first runs_per_row entries per row
    out = torch.zeros((B, S_q, G, K, 2), dtype=torch.int32, device=device)
    # Positions within the row to write (0..K-1): take the row-local run_id at run starts
    run_id_at_starts = (run_id.reshape(-1, K))[run_start_flat]
    # Build linear indices for advanced indexing:
    # map flat run order back to (row, pos)
    row_of_run = torch.repeat_interleave(row_ids, runs_per_row_flat)
    pos_in_row = run_id_at_starts  # 0..runs_per_row[row]-1
    b = (row_of_run // (S_q * G)).to(torch.int64)
    rem = row_of_run % (S_q * G)
    t = (rem // G).to(torch.int64)
    g = (rem % G).to(torch.int64)
    p = pos_in_row.to(torch.int64)
    # Scatter only valid runs
    if valid_runs.any():
        vr = valid_runs.to(torch.bool)
        b_v = b[vr]
        t_v = t[vr]
        g_v = g[vr]
        p_v = p[vr]
        out[b_v, t_v, g_v, p_v, 0] = start_tok_flat[vr].to(torch.int32)
        out[b_v, t_v, g_v, p_v, 1] = end_tok_flat[vr].to(torch.int32)

    if _nvtx:
        try:
            torch.cuda.nvtx.range_pop()
        except Exception:
            pass

    return out

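The run detection at the heart of the v2 conversion is a shift-and-compare over the sorted, -1-padded index vector. A single-row sketch of the same trick (standalone, toy values):

```python
import torch

# One sorted, -1-padded row of selected block ids; 3,4,5 and 9,10 form merged runs.
ids = torch.tensor([0, 3, 4, 5, 9, 10, -1, -1])
valid = ids.ge(0)
x = torch.where(valid, ids, torch.full_like(ids, -2))        # -2 sentinel

x_shift = torch.cat([torch.full_like(x[:1], -2), x[:-1]])     # previous element
prev_valid = x_shift.ge(0)
diff = x - x_shift
adjacent_or_dup = (diff.eq(1) | diff.eq(0)) & prev_valid      # continues a run
run_start = valid & (~adjacent_or_dup | ~prev_valid)

assert run_start.nonzero().flatten().tolist() == [0, 1, 4]    # runs: [0], [3..5], [9..10]
```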
def map_pcmp_to_pslc_slow_path(p_cmp_all: torch.Tensor, meta: BlockMeta) -> torch.Tensor:
    """
    M8: Eq.9 slow path verifier - explicit mathematical computation.

    This function implements the exact mathematical definition by using the
    CSR mapping directly instead of recomputing overlaps. This ensures it
    matches the fast path exactly.

    Args:
        p_cmp_all: [B,S,G,h,S_cmp] compressed probabilities
        meta: Block metadata with overlap mapping

    Returns:
        p_slc_all: [B,S,G,h,S_sel] selection probabilities
    """
    B, S, G, h, S_cmp = p_cmp_all.shape
    device = p_cmp_all.device
    S_sel = meta.sel_starts.numel()

    if S_cmp == 0:
        return torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)

    # Use the CSR mapping directly (same as the fast path but with explicit loops)
    p_slc_all = torch.zeros((B, S, G, h, S_sel), device=device, dtype=p_cmp_all.dtype)

    indptr = meta.M_csl_indptr.to(device)
    indices = meta.M_csl_indices.to(device)
    values = meta.M_csl_values.to(device, dtype=p_cmp_all.dtype)

    # For each compressed block (CSR row)
    for cmp_i in range(min(S_cmp, len(indptr) - 1)):
        start = int(indptr[cmp_i].item())
        end = int(indptr[cmp_i + 1].item())

        if start == end:
            continue

        # Get the selection blocks this compressed block contributes to
        sel_cols = indices[start:end]
        weights = values[start:end]

        # Add the weighted contribution to each selection block
        for sel_idx, weight in zip(sel_cols, weights):
            sel_idx = int(sel_idx.item())
            if sel_idx < S_sel:
                p_slc_all[..., sel_idx] += p_cmp_all[..., cmp_i] * float(weight.item())

    return p_slc_all


def verify_mapping_equivalence(
    p_cmp_all: torch.Tensor, meta: BlockMeta, rtol: float = 1e-5, atol: float = 1e-8
) -> tuple[bool, dict]:
    """
    M8: Verify the fast COO path matches the slow mathematical path (Eq.9 verification).

    Args:
        p_cmp_all: Compressed probabilities to test
        meta: Block metadata
        rtol: Relative tolerance for comparison
        atol: Absolute tolerance for comparison

    Returns:
        (is_equivalent, details): True if paths match, plus diagnostic info
    """
    # Only run verification if explicitly requested via env flag
    if os.getenv("NSA_VERIFY_EQ9_MAPPING", "0").lower() not in ("1", "true", "yes"):
        return True, {"status": "skipped", "reason": "NSA_VERIFY_EQ9_MAPPING not set"}

    with torch.no_grad():
        # Compute both paths
        fast_result = map_pcmp_to_pslc_batched(p_cmp_all, meta)
        slow_result = map_pcmp_to_pslc_slow_path(p_cmp_all, meta)

        # Compare results
        is_close = torch.allclose(fast_result, slow_result, rtol=rtol, atol=atol)

        # Compute diagnostic metrics
        abs_diff = (fast_result - slow_result).abs()
        max_abs_diff = abs_diff.max().item()
        mean_abs_diff = abs_diff.mean().item()
        rel_diff = abs_diff / (slow_result.abs() + atol)
        max_rel_diff = rel_diff.max().item()

        details = {
            "status": "verified" if is_close else "mismatch",
            "max_abs_diff": max_abs_diff,
            "mean_abs_diff": mean_abs_diff,
            "max_rel_diff": max_rel_diff,
            "shape": list(p_cmp_all.shape),
            "rtol": rtol,
            "atol": atol,
        }

        if not is_close:
            from nsa.core.debug import log

            log(
                "error.eq9_mapping_mismatch",
                msg="Fast COO path does not match slow mathematical path",
                **details,
            )

        return is_close, details


def validate_selection_determinism(
    p_grp: torch.Tensor, meta: BlockMeta, n_top: int, t_token: int, num_trials: int = 5
) -> bool:
    """Validate that selection is deterministic by running it multiple times.

    Args:
        p_grp: Group probabilities [B,G,S_sel]
        meta: Block metadata
        n_top: Number of top blocks to select
        t_token: Current token position
        num_trials: Number of trials to test determinism

    Returns:
        True if all trials produce identical results
    """
    # Only run validation if explicitly requested via env flag
    if os.getenv("NSA_VALIDATE_SELECTION_DETERMINISM", "0").lower() not in ("1", "true", "yes"):
        return True

    if p_grp.requires_grad:
        # Don't validate during training to avoid affecting gradients
        return True

    with torch.no_grad():
        results = []
        for trial in range(num_trials):
            ranges = select_topn_ranges(
                p_grp.clone(), meta, n_top, t_token, True, 2, _skip_validation=True
            )
            results.append(ranges.clone())

        # Check that all results are identical
        for i in range(1, num_trials):
            if not torch.equal(results[0], results[i]):
                from nsa.core.debug import log

                log(
                    "error.selection_nondeterministic",
                    msg=f"Selection non-deterministic: trial 0 != trial {i}",
                    trial_0_shape=list(results[0].shape),
                    trial_i_shape=list(results[i].shape),
                )
                return False

    return True
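An illustrative way to exercise the Eq.9 verifier from a test, reusing the stand-in metadata pattern from the earlier sketch (values are hypothetical; real metadata comes from nsa/core/block_index.py):

```python
import os
import torch
from types import SimpleNamespace
from nsa.core.selection_scorer import verify_mapping_equivalence

os.environ["NSA_VERIFY_EQ9_MAPPING"] = "1"   # the check is a no-op without this flag

# Stand-in metadata carrying CSR and COO views of the same 2x2 mapping.
meta = SimpleNamespace(
    sel_starts=torch.tensor([0, 64]),
    M_csl_indptr=torch.tensor([0, 1, 3]),
    M_csl_indices=torch.tensor([0, 0, 1]),
    M_csl_values=torch.tensor([1.0, 0.5, 0.5]),
    M_csl_coo_indices=torch.tensor([[0, 1, 1], [0, 0, 1]]),
    M_csl_coo_values=torch.tensor([1.0, 0.5, 0.5]),
)

p_cmp = torch.softmax(torch.randn(1, 2, 1, 2, 2), dim=-1)   # [B,S,G,h,S_cmp]
ok, details = verify_mapping_equivalence(p_cmp, meta)
assert ok, details
```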
nsa/data_pipeline.py
ADDED
|
@@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""Data pipeline utilities for streaming and local datasets.

Provides a FineWeb-Edu IterableDataset and simple local JSONL/TXT loaders.
This module is optional; scripts/train_showcase.py currently uses a simpler
loader in scripts/datasets. Migrate incrementally as needed.
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Callable, Iterator, List

Tokenizer = Callable[[str], List[int]]


@dataclass
class Shard:
    mod: int = 1
    rem: int = 0


def fineweb_stream_batches(
    encode: Tokenizer,
    seq_len: int,
    batch_size: int,
    shard: Shard = Shard(),
    report_docs: int = 1000,
) -> Iterator[List[List[int]]]:
    try:
        from datasets import Features, Value, load_dataset  # type: ignore
    except Exception as e:
        raise RuntimeError("datasets package required. Install with: pip install datasets") from e

    features = Features(
        {
            "text": Value("string"),
            "id": Value("string"),
            "dump": Value("string"),
            "url": Value("string"),
            "file_path": Value("string"),
            "language": Value("string"),
            "language_score": Value("float64"),
            "token_count": Value("int64"),
            "score": Value("float64"),
            "int_score": Value("int64"),
        }
    )
    ds = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True, features=features)
    buf: List[int] = []
    batch: List[List[int]] = []
    seen = 0
    import time as _t

    t0 = _t.time()
    last = t0
    for ex in ds:
        if seen % shard.mod != shard.rem:
            seen += 1
            continue
        seen += 1
        if report_docs and seen % report_docs == 0:
            dt = _t.time() - last
            print(f"[fwe] seen_docs={seen} dt={dt:.1f}s buf={len(buf)}", flush=True)
            last = _t.time()
        text = ex.get("text") or ""
        if not text:
            continue
        toks = encode(text)
        if not toks:
            continue
        buf.extend(toks)
        while len(buf) >= seq_len:
            seq = buf[:seq_len]
            buf = buf[seq_len:]
            batch.append(seq)
            if len(batch) >= batch_size:
                yield batch[:batch_size]
                batch = batch[batch_size:]


def fineweb_stream_batches_batched(
    encode_batch: Callable[[List[str]], List[List[int]]],
    seq_len: int,
    batch_size: int,
    shard: Shard = Shard(),
    report_docs: int = 1000,
    doc_batch: int = 64,
) -> Iterator[List[List[int]]]:
    """Streaming FineWeb-Edu with batched tokenization and fixed-length packing.

    - encode_batch: function mapping a list of texts -> list of token id lists
    - Packs contiguous tokens from a rolling buffer into fixed seq_len examples
    - Yields Python lists of shape [batch_size][seq_len]
    """
    try:
        from datasets import Features, Value, load_dataset  # type: ignore
    except Exception as e:
        raise RuntimeError("datasets package required. Install with: pip install datasets") from e

    features = Features(
        {
            "text": Value("string"),
            "id": Value("string"),
            "dump": Value("string"),
            "url": Value("string"),
            "file_path": Value("string"),
            "language": Value("string"),
            "language_score": Value("float64"),
            "token_count": Value("int64"),
            "score": Value("float64"),
            "int_score": Value("int64"),
        }
    )
    ds = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True, features=features)

    buf: List[int] = []
    batch: List[List[int]] = []
    seen = 0
    acc_texts: List[str] = []
    import time as _t

    last = _t.time()
    for ex in ds:
        if seen % shard.mod != shard.rem:
            seen += 1
            continue
        seen += 1
        if report_docs and seen % report_docs == 0:
            dt = _t.time() - last
            print(
                f"[fwe] (batched) seen_docs={seen} dt={dt:.1f}s buf={len(buf)} acc_texts={len(acc_texts)}",
                flush=True,
            )
            last = _t.time()
        text = ex.get("text") or ""
        if not text:
            continue
        acc_texts.append(text)
        if len(acc_texts) < max(1, int(doc_batch)):
            continue
        # Batched tokenize
        try:
            toks_list = encode_batch(acc_texts)
        except Exception:
            # Fall back to per-doc encoding if the batch path fails
            toks_list = []
            for t in acc_texts:
                try:
                    toks_list.append(encode_batch([t])[0])
                except Exception:
                    toks_list.append([])
        acc_texts.clear()
        # Fill the rolling buffer and emit fixed-length sequences
        for toks in toks_list:
            if not toks:
                continue
            buf.extend(toks)
            while len(buf) >= seq_len:
                seq = buf[:seq_len]
                buf = buf[seq_len:]
                batch.append(seq)
                if len(batch) >= batch_size:
                    yield batch[:batch_size]
                    batch = batch[batch_size:]


def local_jsonl_or_txt_batches(
    path: str,
    encode: Tokenizer,
    seq_len: int,
    batch_size: int,
) -> Iterator[List[List[int]]]:
    is_jsonl = path.endswith(".jsonl")
    buf: List[int] = []
    batch: List[List[int]] = []
    with open(path, encoding="utf-8", errors="ignore") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            text = line
            if is_jsonl:
                try:
                    obj = json.loads(line)
                    if isinstance(obj, dict) and isinstance(obj.get("text"), str):
                        text = obj["text"]
                except Exception:
                    pass
            toks = encode(text)
            if not toks:
                continue
            buf.extend(toks)
            while len(buf) >= seq_len:
                seq = buf[:seq_len]
                buf = buf[seq_len:]
                batch.append(seq)
                if len(batch) >= batch_size:
                    yield batch[:batch_size]
                    batch = batch[batch_size:]
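Example driving the local loader with a byte-level encoder in the spirit of the repo's byte-256 tokenizer (the path and encoder are placeholders, not part of the export):

```python
from nsa.data_pipeline import local_jsonl_or_txt_batches

def encode(text: str) -> list[int]:
    # Stand-in byte-level encoder; the real tokenizer lives in tokenization_nsa.py.
    return list(text.encode("utf-8"))

# "data.jsonl" is a hypothetical local file with one {"text": ...} object per line.
for batch in local_jsonl_or_txt_batches("data.jsonl", encode, seq_len=128, batch_size=4):
    assert len(batch) == 4 and all(len(seq) == 128 for seq in batch)
    break
```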
nsa/kernels/__init__.py
ADDED
|
@@ -0,0 +1 @@
from __future__ import annotations
nsa/kernels/flash_wrappers.py
ADDED
|
@@ -0,0 +1,228 @@
from __future__ import annotations

import os

import torch
import torch.nn.functional as F

from nsa.core.debug import log


def _env_bool(name: str, default: bool = False) -> bool:
    v = os.getenv(name, "1" if default else "0").lower()
    return v in ("1", "true", "yes", "on")


def flash_attn_version() -> str | None:
    """Return the flash-attn version string if importable, else None."""
    try:
        import flash_attn as _fa  # type: ignore

        return getattr(_fa, "__version__", None)
    except Exception:
        return None


def is_flash_available() -> bool:
    """Return True if the flash-attn dense API is importable."""
    try:
        from flash_attn import flash_attn_func  # type: ignore

        _ = flash_attn_func  # silence linter
        return True
    except Exception:
        return False


def is_flash_varlen_available() -> bool:
    """Return True if a varlen API is importable (either QKV or KV-packed)."""
    try:
        from flash_attn import flash_attn_varlen_func  # type: ignore

        _ = flash_attn_varlen_func
        return True
    except Exception:
        try:
            from flash_attn import flash_attn_varlen_kvpacked_func  # type: ignore

            _ = flash_attn_varlen_kvpacked_func
            return True
        except Exception:
            return False


def fa2_supported_verbose(
    device: torch.device, dtype: torch.dtype, head_dim: int
) -> tuple[bool, str]:
    """
    Conservative capability probe with a reason string for logging.
    We do not hard-fail on dtype, relying on try/except at call sites.
    """
    if device.type != "cuda":
        return False, "device_not_cuda"
    if head_dim % 8 != 0:
        return False, "head_dim_not_multiple_of_8"
    if not (is_flash_varlen_available() or is_flash_available()):
        return False, "flash_attn_not_importable"
    # Optional version floor (best-effort)
    ver = flash_attn_version()
    if ver is None:
        # Unknown version; still allow
        return True, "ok"
    # Allow all known versions; attach the version for logs
    return True, f"ok_v{ver}"


def fa2_supported(device: torch.device, dtype: torch.dtype, head_dim: int) -> bool:
    ok, _ = fa2_supported_verbose(device, dtype, head_dim)
    return ok

| 79 |
+
def attention_bgh(
|
| 80 |
+
Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, causal: bool = True
|
| 81 |
+
) -> torch.Tensor:
|
| 82 |
+
"""
|
| 83 |
+
Q: [B,G,h,Dk], K/V: [B,G,S,D*] -> out [B,G,h,Dv]
|
| 84 |
+
Prefer flash-attn if available; fallback to SDPA.
|
| 85 |
+
"""
|
| 86 |
+
B, G, h, Dk = Q.shape
|
| 87 |
+
S = K.shape[2]
|
| 88 |
+
# Try FA-2 dense path first
|
| 89 |
+
if is_flash_available():
|
| 90 |
+
try:
|
| 91 |
+
from flash_attn import flash_attn_func # type: ignore
|
| 92 |
+
|
| 93 |
+
# Reshape without materializing copies
|
| 94 |
+
q = Q.transpose(1, 2).reshape(B, G * h, 1, Dk) # [B,G*h,1,Dk]
|
| 95 |
+
k = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B, G * h, S, Dk) # [B,G*h,S,Dk]
|
| 96 |
+
v = (
|
| 97 |
+
V.unsqueeze(2).expand(B, G, h, S, V.shape[-1]).reshape(B, G * h, S, V.shape[-1])
|
| 98 |
+
) # [B,G*h,S,Dv]
|
| 99 |
+
if _env_bool("NSA_DEBUG_TIMING"):
|
| 100 |
+
log("fa2.bgh.path", path="fa2.dense", B=B, G=G, h=h, S=S, Dk=Dk)
|
| 101 |
+
o = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=causal)
|
| 102 |
+
o = o.reshape(B, G, h, -1)
|
| 103 |
+
if not torch.isfinite(o).all():
|
| 104 |
+
log("warn.flash_bgh_nonfinite", path="fa2.dense")
|
| 105 |
+
return torch.nan_to_num(o, nan=0.0)
|
| 106 |
+
except Exception:
|
| 107 |
+
pass
|
| 108 |
+
# SDPA fallback
|
| 109 |
+
if _env_bool("NSA_DEBUG_TIMING"):
|
| 110 |
+
log("fa2.bgh.path", path="sdpa", B=B, G=G, h=h, S=S, Dk=Dk)
|
| 111 |
+
# Expand heads via view/expand to avoid materializing copies
|
| 112 |
+
q2 = Q.reshape(B * G * h, 1, Dk).contiguous()
|
| 113 |
+
k2 = K.unsqueeze(2).expand(B, G, h, S, Dk).reshape(B * G * h, S, Dk).contiguous()
|
| 114 |
+
v2 = (
|
| 115 |
+
V.unsqueeze(2)
|
| 116 |
+
.expand(B, G, h, S, V.shape[-1])
|
| 117 |
+
.reshape(B * G * h, S, V.shape[-1])
|
| 118 |
+
.contiguous()
|
| 119 |
+
)
|
| 120 |
+
attn = F.scaled_dot_product_attention(q2, k2, v2, is_causal=causal)
|
| 121 |
+
o = attn.squeeze(1).reshape(B, G, h, -1)
|
| 122 |
+
return torch.nan_to_num(o, nan=0.0)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def attention_fa2_dense_batch(
|
| 126 |
+
q: torch.Tensor,
|
| 127 |
+
k: torch.Tensor,
|
| 128 |
+
v: torch.Tensor,
|
| 129 |
+
*,
|
| 130 |
+
causal: bool,
|
| 131 |
+
) -> torch.Tensor:
|
| 132 |
+
"""
|
| 133 |
+
Best-effort dense FA-2 call for a batch of independent rows.
|
| 134 |
+
Shapes:
|
| 135 |
+
- q: [N, Tq, h, D]
|
| 136 |
+
- k: [N, Tk, h, D]
|
| 137 |
+
- v: [N, Tk, h, Dv]
|
| 138 |
+
Returns: o [N, Tq, h, Dv]
|
| 139 |
+
Falls back to SDPA if flash-attn unavailable.
|
| 140 |
+
"""
|
| 141 |
+
# Ensure contiguous tensors for FA-2
|
| 142 |
+
q = q.contiguous()
|
| 143 |
+
k = k.contiguous()
|
| 144 |
+
v = v.contiguous()
|
| 145 |
+
try:
|
| 146 |
+
from flash_attn import flash_attn_func # type: ignore
|
| 147 |
+
|
| 148 |
+
if _env_bool("NSA_DEBUG_TIMING"):
|
| 149 |
+
log(
|
| 150 |
+
"fa2.batch.path",
|
| 151 |
+
path="fa2.dense",
|
| 152 |
+
N=int(q.shape[0]),
|
| 153 |
+
Tq=int(q.shape[1]),
|
| 154 |
+
Tk=int(k.shape[1]),
|
| 155 |
+
)
|
| 156 |
+
return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=causal)
|
| 157 |
+
except Exception:
|
| 158 |
+
# SDPA fallback per row
|
| 159 |
+
N, Tq, h, D = q.shape
|
| 160 |
+
Tk = k.shape[1]
|
| 161 |
+
Dv = v.shape[-1]
|
| 162 |
+
if _env_bool("NSA_DEBUG_TIMING"):
|
| 163 |
+
log("fa2.batch.path", path="sdpa", N=int(N), Tq=int(Tq), Tk=int(Tk))
|
| 164 |
+
q2 = q.reshape(N * h, Tq, D)
|
| 165 |
+
k2 = k.reshape(N * h, Tk, D)
|
| 166 |
+
v2 = v.reshape(N * h, Tk, Dv)
|
| 167 |
+
out = F.scaled_dot_product_attention(q2, k2, v2, is_causal=causal)
|
| 168 |
+
return out.reshape(N, h, Tq, Dv).permute(0, 2, 1, 3).contiguous()
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
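
For orientation, a minimal smoke test of the helpers above (an editorial sketch, not part of the export): on a CPU-only machine `fa2_supported_verbose` reports the `device_not_cuda` reason and `attention_bgh` silently takes the SDPA fallback, so the wrappers stay usable without flash-attn installed.

```python
# Sketch: exercise the capability probe and the SDPA fallback on CPU.
import torch

from nsa.kernels.flash_wrappers import attention_bgh, fa2_supported_verbose

B, G, h, S, Dk, Dv = 2, 2, 4, 128, 64, 64
Q = torch.randn(B, G, h, Dk)   # one decode-step query per head
K = torch.randn(B, G, S, Dk)   # S cached keys per KV group
V = torch.randn(B, G, S, Dv)

ok, reason = fa2_supported_verbose(Q.device, Q.dtype, Dk)
print(ok, reason)              # on CPU: False device_not_cuda

out = attention_bgh(Q, K, V, causal=True)
print(tuple(out.shape))        # (2, 2, 4, 64)
```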
def attention_fa2_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    *,
    causal: bool,
) -> torch.Tensor:
    """
    Best-effort varlen FA-2 call with separate Q/K/V packing.
    Shapes:
    - q: [total_q, h, D], k: [total_k, h, D], v: [total_k, h, Dv]
    - cu_seqlens_*: int32 [N+1]
    Returns: [total_q, h, Dv] packed output.
    Raises NotImplementedError when no varlen API is importable; the caller
    is expected to fall back (e.g. to dense batching with padding).
    """
    # Ensure contiguous tensors for FA-2
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    try:
        from flash_attn import flash_attn_varlen_func  # type: ignore

        return flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p=0.0,
            softmax_scale=None,
            causal=causal,
        )
    except Exception:
        # Try the KV-packed API variant
        try:
            from flash_attn import flash_attn_varlen_kvpacked_func  # type: ignore

            # Build packed KV as [total_k, 2, h, D] (the stack requires D == Dv)
            kv_packed = torch.stack([k, v], dim=1).contiguous()
            return flash_attn_varlen_kvpacked_func(
                q,
                kv_packed,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                dropout_p=0.0,
                softmax_scale=None,
                causal=causal,
            )
        except Exception:
            raise NotImplementedError("FA-2 varlen API not available; caller should fallback")
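
A sketch of the packing convention the varlen path expects (illustrative lengths; `h` and `D` stand in for the real head count and head dim): `cu_seqlens_q`/`cu_seqlens_k` are int32 prefix sums over per-row lengths, and the rows are concatenated along dim 0.

```python
# Sketch: cu_seqlens for two rows with query lengths [3, 5] and key lengths [7, 9].
import torch

lens_q = torch.tensor([3, 5], dtype=torch.int32)
lens_k = torch.tensor([7, 9], dtype=torch.int32)
cu_q = torch.zeros(lens_q.numel() + 1, dtype=torch.int32)
cu_q[1:] = torch.cumsum(lens_q, dim=0)
cu_k = torch.zeros(lens_k.numel() + 1, dtype=torch.int32)
cu_k[1:] = torch.cumsum(lens_k, dim=0)
print(cu_q.tolist())  # [0, 3, 8]   -> q packs to [8, h, D]
print(cu_k.tolist())  # [0, 7, 16]  -> k/v pack to [16, h, D*]
# attention_fa2_varlen(q, k, v, cu_q, cu_k, max_seqlen_q=5, max_seqlen_k=9, causal=True)
```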
nsa/model/__init__.py
ADDED
@@ -0,0 +1 @@
from __future__ import annotations
nsa/model/llama_block_nsa.py
ADDED
@@ -0,0 +1,129 @@
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from nsa.cache.kv_cache import NSA_KV
from nsa.core.block_index import build_block_meta
from nsa.core.nsa_attention import NSAAttention


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B,S,dim]
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return (x * rms) * self.weight
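
In plain terms, the norm above computes y = x × rsqrt(mean(x²) + eps) × weight. A quick numeric check (editorial, not part of the file):

```python
# mean(x^2) = (9 + 16) / 2 = 12.5, so y ≈ x / 3.5355 before the learned scale.
import torch

x = torch.tensor([[3.0, 4.0]])
rms = x.pow(2).mean(dim=-1, keepdim=True).add(1e-6).rsqrt()
print(x * rms)  # tensor([[0.8485, 1.1314]])
```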
class MLP(nn.Module):
    def __init__(self, dim: int, hidden_mult: int = 4) -> None:
        super().__init__()
        h = hidden_mult * dim
        self.fc1 = nn.Linear(dim, h, bias=False)
        self.fc2 = nn.Linear(h, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.silu(self.fc1(x)))


class LlamaBlockNSA(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_groups: int,
        d_k: int,
        d_v: int,
        l: int = 32,
        d: int = 16,
        l_sel: int = 64,
        n_sel: int = 16,
        w: int = 512,
    ) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = NSAAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_groups=n_kv_groups,
            d_k=d_k,
            d_v=d_v,
            l=l,
            d=d,
            l_sel=l_sel,
            n_sel=n_sel,
            w=w,
        )
        self.norm2 = RMSNorm(dim)
        self.mlp = MLP(dim)

    def _build_empty_kv(self, x: torch.Tensor) -> NSA_KV:
        B, S, dim = x.shape
        device = x.device
        G = self.attn.n_kv_groups
        Dk = self.attn.d_k
        Dv = self.attn.d_v
        zeros_k = torch.zeros((B, G, 0, Dk), device=device, dtype=x.dtype)
        zeros_v = torch.zeros((B, G, 0, Dv), device=device, dtype=x.dtype)
        meta = build_block_meta(
            seq_len=0,
            l=self.attn.l,
            d=self.attn.d,
            l_sel=self.attn.l_sel,
            n_sel=self.attn.n_sel,
            w=self.attn.w,
        )
        return NSA_KV(
            K_sel=zeros_k.clone(),
            V_sel=zeros_v.clone(),
            K_win=zeros_k.clone(),
            V_win=zeros_v.clone(),
            K_cmp_raw_seq=zeros_k.clone(),
            V_cmp_raw_seq=zeros_v.clone(),
            K_cmp=zeros_k.clone(),
            V_cmp=zeros_v.clone(),
            win_ptr=torch.zeros((B, G), dtype=torch.int64, device=device),
            cmp_emit_next=torch.zeros((B, G), dtype=torch.int64, device=device),
            meta=meta,
            reads_pred=torch.zeros((0,), dtype=torch.int64, device=device),
            reads_act_total=torch.zeros((0,), dtype=torch.int64, device=device),
            reads_act_sel=torch.zeros((0,), dtype=torch.int64, device=device),
            reads_act_cmp=torch.zeros((0,), dtype=torch.int64, device=device),
            reads_act_win=torch.zeros((0,), dtype=torch.int64, device=device),
        )

    def forward_attn(self, x: torch.Tensor) -> torch.Tensor:
        """Attention sub-layer with residual.

        Exposed to allow gradient-checkpoint splits that exclude attention
        from checkpointing when dynamic routing could cause recompute
        mismatches.
        """
        B, S, dim = x.shape
        res = x
        xn = self.norm1(x)
        kv = self._build_empty_kv(x)
        out, _kv = self.attn(xn, kv=kv, prefill=True)
        return res + out

    def forward_mlp(self, x: torch.Tensor) -> torch.Tensor:
        """MLP sub-layer with residual.

        Can be safely checkpointed independently of attention.
        """
        res = x
        return res + self.mlp(self.norm2(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Default monolithic forward preserves prior behavior
        x = self.forward_attn(x)
        x = self.forward_mlp(x)
        return x


class _EmptyKVLike:
    pass
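
A hedged usage sketch for the block (editorial; the dims below are made up for illustration and assume the repo's `nsa` package is importable, while the NSA defaults `l=32, d=16, l_sel=64, n_sel=16, w=512` come from the signature above):

```python
# Sketch: run one NSA block over a toy batch.
import torch

from nsa.model.llama_block_nsa import LlamaBlockNSA

block = LlamaBlockNSA(dim=256, n_heads=8, n_kv_groups=2, d_k=32, d_v=32)
x = torch.randn(1, 64, 256)  # [B, S, dim]
y = block(x)                 # norm -> NSA attention -> residual, then norm -> MLP -> residual
print(y.shape)               # torch.Size([1, 64, 256])
```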
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{}
tokenization_nsa.py
ADDED
@@ -0,0 +1,73 @@
# Remote code: byte-level tokenizer for NSA
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer


class NSAByteTokenizer(PreTrainedTokenizer):
    """A simple byte-level tokenizer with fixed vocab size 256.

    - Encodes the UTF-8 bytes of the input string as token ids 0..255.
    - No special tokens by default; EOS/PAD can be configured via the special tokens map.
    - Decoding uses UTF-8 with replacement for invalid sequences.
    """

    def __init__(self, **kwargs):
        # Build a stable 256-entry vocab mapping before base init (the base may query the vocab)
        self._vocab: Dict[str, int] = {f"<{i}>": i for i in range(256)}
        self._ids_to_tokens: Dict[int, str] = {i: f"<{i}>" for i in range(256)}
        super().__init__(**kwargs)
        # Only return input_ids and attention_mask to avoid unused token_type_ids in generation
        self.model_input_names = ["input_ids", "attention_mask"]

    @property
    def vocab_size(self) -> int:  # type: ignore[override]
        return 256

    def get_vocab(self) -> Dict[str, int]:  # type: ignore[override]
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:  # type: ignore[override]
        data = text.encode("utf-8", errors="replace")
        return [f"<{b}>" for b in data]

    def _convert_token_to_id(self, token: str) -> int:  # type: ignore[override]
        if token in self._vocab:
            return self._vocab[token]
        # Fallback: try to parse the numeric id inside <...>
        if token.startswith("<") and token.endswith(">"):
            try:
                v = int(token[1:-1])
                if 0 <= v < 256:
                    return v
            except Exception:
                pass
        return 0

    def _convert_id_to_token(self, index: int) -> str:  # type: ignore[override]
        return self._ids_to_tokens.get(int(index) % 256, "<0>")

    def convert_tokens_to_string(self, tokens: List[str]) -> str:  # type: ignore[override]
        bs = []
        for t in tokens:
            if t in self._vocab:
                bs.append(self._vocab[t])
            else:
                try:
                    if t.startswith("<") and t.endswith(">"):
                        v = int(t[1:-1])
                        if 0 <= v < 256:
                            bs.append(v)
                            continue
                except Exception:
                    pass
        return bytes(bs).decode("utf-8", errors="replace")

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:  # type: ignore[override]
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):  # type: ignore[override]
        # Nothing to save besides the special tokens map handled by the base class.
        return ()
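
A quick round trip with the tokenizer above (editorial sketch; instantiating the class directly assumes no extra vocab files are required, which holds here since the 256-entry vocab is built in `__init__`):

```python
# Sketch: UTF-8 bytes in, ids 0..255 out, and back again.
from tokenization_nsa import NSAByteTokenizer

tok = NSAByteTokenizer()
enc = tok("Hello, NSA!")
print(enc["input_ids"][:5])          # [72, 101, 108, 108, 111]
print(tok.decode(enc["input_ids"]))  # Hello, NSA!
```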
tokenizer_config.json
ADDED
@@ -0,0 +1,11 @@
{
  "tokenizer_class": "NSAByteTokenizer",
  "model_max_length": 2048,
  "chat_template": "{% for m in messages %}{% if m['role']=='user' %}<|user|>{{ m['content'] }}\n{% elif m['role']=='assistant' %}<|assistant|>{{ m['content'] }}\n{% endif %}{% endfor %}<|assistant|>",
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_nsa.NSAByteTokenizer",
      null
    ]
  }
}
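
One note on the chat template: the `<|user|>` / `<|assistant|>` markers are plain text to this byte-level vocab (there are no special tokens), so they are encoded byte by byte. A hedged sketch of rendering the template, assuming a transformers version with `apply_chat_template`:

```python
# Sketch: render the chat template to a string without tokenizing.
from transformers import AutoTokenizer

t = AutoTokenizer.from_pretrained("seconds-0/nsa-117m-byte", trust_remote_code=True)
messages = [{"role": "user", "content": "Hi"}]
print(t.apply_chat_template(messages, tokenize=False))
# <|user|>Hi
# <|assistant|>
```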