PyTorch
llama
hyx21 commited on
Commit
ab6c8f7
·
verified ·
1 Parent(s): 751f3d6

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +164 -0
README.md ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ <div align="center">
5
+ <h1>DashAttention</h1>
6
+
7
+ <p><strong>Differentiable and Adaptive Sparse Hierarchical Attention</strong></p>
8
+ </div>
9
+
10
+ <div align="center" style="line-height: 1;">
11
+ <a href="https://github.com/fasa-org/dash-attention" style="margin: 2px;">
12
+ <img alt="Code" src="https://img.shields.io/badge/GitHub-100000?style=for-the-badge&logo=github&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
13
+ </a>
14
+ <a href="https://huggingface.co/collections/fasa-org/dashattention" style="margin: 2px;">
15
+ <img alt="Hugging Face" src="https://img.shields.io/badge/DashAttention-fcd022?style=for-the-badge&logo=huggingface&logoColor=000&labelColor" style="display: inline-block; vertical-align: middle;"/>
16
+ </a>
17
+ <a href="https://arxiv.org/abs/2605.18753" style="margin: 2px;">
18
+ <img alt="Paper" src="https://img.shields.io/badge/Paper-2605.18753-b31b1b.svg" style="display: inline-block; vertical-align: middle;"/>
19
+ </a>
20
+ </div>
21
+
22
+
23
+
24
+ ## Installation
25
+
26
+ For the usage of DashAttention kernels and running the example, please run the following script:
27
+ ```
28
+ pip install -e .
29
+ ```
30
+
31
+ For benchmark environment setup, please refer to each corresponding folder.
32
+
33
+ ## Usage
34
+
35
+ The dash attention's interface can be used as follows:
36
+
37
+ ```python
38
+ queries = torch.randn(batch, query_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
39
+ keys = torch.randn(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
40
+ values = torch.randn(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
41
+ head_cls = torch.randn(kv_heads, head_dim, device=device, dtype=dtype).contiguous()
42
+
43
+ model = dash_attn(
44
+ chunk_size=chunk_size,
45
+ enable_gqa=True,
46
+ estimate_diagonal=True,
47
+ return_active_blocks=True,
48
+ )
49
+
50
+ out, active_blocks = model(queries, keys, values, head_cls)
51
+ ```
52
+
53
+ We also provide an example on how to use DashAttention in Llama-architecture models in [here](./example/run_niah.py).
54
+ ```
55
+ python ./example/run_niah.py
56
+ ```
57
+
58
+ ## Documentation
59
+
60
+ DashAttention implements the attention mechanism introduced in [DashAttention: Differentiable and Adaptive Sparse Hierarchical Attention](https://arxiv.org/abs/2605.18753). The method replaces fixed-budget top-k block routing with an adaptive, differentiable sparse router, then refines the selected regions with token-level softmax attention.
61
+
62
+ ### How it works
63
+
64
+ The implementation follows the three-stage hierarchy described in the paper:
65
+
66
+ 1. **Local chunk summarization**: `dash_attn.prefill.summarize_chunk` and `dash_attn.decoding.summarize_chunk` build one learned key summary per KV chunk.
67
+ 2. **Entmax block routing**: `score_blocks` computes sparse chunk supports and routing priors from query-to-summary scores.
68
+ 3. **Prior-induced sparse softmax**: `full_attn` applies token-level attention only over routed chunks, using the Stage 1 prior to preserve differentiability through the hierarchy.
69
+
70
+ The public kernel wrapper is [`dash_attn.dash_attn_interface.dash_attn`](./dash_attn/dash_attn_interface.py). It supports both prefill and decoding: prefill summarizes the current sequence and stores complete chunk summaries, while decoding reuses the chunk-summary cache and appends newly completed chunks.
71
+
72
+ ### Core API
73
+
74
+ ```python
75
+ from dash_attn import dash_attn
76
+
77
+ attn = dash_attn(
78
+ chunk_size=64,
79
+ enable_gqa=True,
80
+ estimate_diagonal=True,
81
+ scaling_factor=1.0,
82
+ return_active_blocks=False,
83
+ )
84
+ ```
85
+
86
+ Important arguments:
87
+
88
+ | Argument | Description |
89
+ |:-|:-|
90
+ | `chunk_size` | Number of tokens per routed KV chunk. |
91
+ | `enable_gqa` | Enables grouped-query attention support when query heads outnumber KV heads. |
92
+ | `estimate_diagonal` | Includes special handling for the current or near-diagonal chunk. |
93
+ | `scaling_factor` | Scales routing logits before sparse block selection; this is the main knob for sparsity. |
94
+ | `return_active_blocks` | Returns the number of active routed blocks per token for sparsity analysis. |
95
+ | `max_chunks` | Preallocated chunk-summary cache capacity used during decoding. |
96
+ | `sigma` | Controls the strength of the Stage 1 routing prior used by Stage 2. |
97
+
98
+ Inputs are expected in `[batch, heads, seq_len, head_dim]` layout for `queries`, `keys`, and `values`; `head_cls` has shape `[kv_heads, head_dim]`.
99
+
100
+ ### Llama integration
101
+
102
+ DashAttention includes a Llama-compatible modeling implementation in [`dash_attn.models.llama`](./dash_attn/models/llama). `LlamaConfig` defaults to `attn_implementation="dash_attn"` and adds DashAttention-specific fields such as `chunk_size`, `estimate_diagonal`, `sigma`, and `scaling_factor`.
103
+
104
+ ```python
105
+ from dash_attn.models.llama import LlamaForCausalLM
106
+
107
+ model = LlamaForCausalLM.from_pretrained(
108
+ "fasa-org/MiniCPM-4-8B-DashAttention",
109
+ attn_implementation="dash_attn",
110
+ torch_dtype="auto",
111
+ )
112
+ ```
113
+
114
+ To inspect routing behavior, call the model with `return_active_blocks=True`, then read `model.get_active_blocks()`.
115
+
116
+ ### Examples and tests
117
+
118
+ - [`example/run_niah.py`](./example/run_niah.py) runs a needle-in-a-haystack style generation example and reports measured sparsity.
119
+ - [`test/test_smoke.py`](./test/test_smoke.py) checks the standalone DashAttention kernel wrapper.
120
+ - [`test/test_llama_dash_attn.py`](./test/test_llama_dash_attn.py) checks the Llama integration and active-block reporting.
121
+
122
+ Run the test suite with:
123
+
124
+ ```bash
125
+ pytest
126
+ ```
127
+
128
+ The current kernels require CUDA-capable hardware.
129
+
130
+ ## Models
131
+
132
+ We release our 8B models for reproducibility.
133
+
134
+ | Model | Link |
135
+ |:-:|:-:|
136
+ | 8B-FullAttn | [Hugging Face](https://huggingface.co/fasa-org/MiniCPM-4-8B-FullAttn) |
137
+ | 8B-InfLLMv2 | [Hugging Face](https://huggingface.co/fasa-org/MiniCPM-4-8B-InfLLMv2) |
138
+ | 8B-NSA | [Hugging Face](https://huggingface.co/fasa-org/MiniCPM-4-8B-NSA) |
139
+ | 8B-DashAttention | [Hugging Face](https://huggingface.co/fasa-org/MiniCPM-4-8B-DashAttention) |
140
+
141
+ The base models we use are [MiniCPM4-1B-Base](https://modelscope.cn/models/OpenBMB/MiniCPM4-1B-Base), [MiniCPM4-3B-Base](https://modelscope.cn/models/OpenBMB/MiniCPM4-3B-Base), and [MiniCPM4-8B-Base](https://modelscope.cn/models/OpenBMB/MiniCPM4-8B-Base).
142
+
143
+ ## Benchmarks
144
+
145
+ - Performance: Please refer to [README](./benchmarks/performance/README.md).
146
+
147
+ ## License
148
+
149
+ This project is released under the [BSD-3-Clause License](./LICENSE).
150
+
151
+ ## Acknowledgement
152
+
153
+ This repository is developed with the aid of [RULER](https://github.com/NVIDIA/RULER), [OLMES](https://github.com/allenai/olmes), [InfLLMv2](https://github.com/OpenBMB/infllmv2_cuda_impl), and [NSA-triton](https://github.com/XunhaoLai/native-sparse-attention-triton).
154
+
155
+ ## Citation
156
+
157
+ ```latex
158
+ @article{dash-attention,
159
+ title={DashAttention: Differentiable and Adaptive Sparse Hierarchical Attention},
160
+ author={Huang, Yuxiang and Gon{\c{c}}alves, Nuno M. T. and Alvetreti, Federico and Li, Lei and Han, Xu and Ponti, Edoardo M. and Martins, Andr{\'e} F. T. and Treviso, Marcos V.},
161
+ journal={arXiv preprint arXiv:2605.18753},
162
+ year={2026}
163
+ }
164
+ ```