---
language:
- en
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
- sparse-attention
- ann-attention
- distillation
- search-projection
- inference-optimization
library_name: pytorch
---
# ann-sparseattention
Search projections for ANN-substituted attention on
[`Qwen/Qwen3-4B-Instruct-2507`](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507).
Code: [github.com/unixsysdev/ann-sparseattention](https://github.com/unixsysdev/ann-sparseattention)
## Current status
Research prototype. The trained projections work, the runtime is a
correctness prototype, and the eval envelope is narrow. Treat reported
numbers as preliminary.
**Validated:** 6-layer pilot on Qwen3-4B-Instruct-2507; WikiText-103 PPL
preserved at K=128 (gap ≈ +0.7%); learned projections retrieve attention-
relevant keys.
**Not yet validated:** 34-layer / whole-model substitution; long-context
tasks (LongBench, RULER, needle); wall-clock speedup vs FlashAttention/SDPA;
KV-cache decode-mode integration; GPU-resident ANN kernel.
**Runtime caveat:** the FAISS path here builds CPU indexes per batch and
the gather step uses dense-style tensor expansion. Compute-reduction
numbers below are *algorithmic scoring reductions, not measured wall-clock
speedups.*
## Relation to RetrievalAttention
RetrievalAttention (Liu et al., 2024) shows that **vanilla ANN over the
model's native Q, K vectors fails** because Q and K live in mismatched
distributions: they were never trained to be each other's nearest
neighbors, only to score via dot product. Their fix is at *index time*:
an attention-aware graph construction (RoarGraph-style).
This work attacks the same problem from the opposite direction. We
**train a tiny shared projection** (`W_Qs`, `W_Ks`, each `[2560, 64]`) so that
`q_search` and `k_search` live in the same distribution by construction.
Off-the-shelf FAISS HNSW with default parameters then suffices.
| Approach | Search space | Index | Trainable |
|---|---|---|---|
| Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no (fails: Q/K OOD) |
| RetrievalAttention | original Q/K | attention-aware graph | no |
| **This work** | **learned Q\_s / K\_s** | **off-the-shelf** | **yes (~2-11M params)** |
Contribution: *eliminate Q/K mismatch at index-build time via distillation,
instead of patching it at search time.* The clean validating experiment
(vanilla FAISS over raw Q/K vs. learned Q\_s/K\_s vs. exact teacher top-K)
is the next planned run.
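A minimal sketch of that comparison, assuming captured per-head `q`/`k` tensors and
that the trained projections multiply the layer's hidden states (as their `[2560, 64]`
shape suggests); variable names and the causal handling are illustrative, not the
repo's API:
```python
import faiss
import torch

def hnsw_recall(queries, keys, reference_topk, K=128):
    """Recall of off-the-shelf FAISS HNSW top-K against an exact reference."""
    index = faiss.IndexHNSWFlat(keys.shape[-1], 32, faiss.METRIC_INNER_PRODUCT)
    index.add(keys.float().cpu().numpy())
    _, retrieved = index.search(queries.float().cpu().numpy(), K)
    hits = [len(set(r.tolist()) & set(t)) / K for r, t in zip(retrieved, reference_topk)]
    return sum(hits) / len(hits)

# q, k: [L, d_head] one head's raw query/key vectors; hidden: [L, 2560] layer input;
# W_Qs, W_Ks: [2560, 64] trained search projections (causality ignored for brevity)
reference_topk = (q @ k.T).topk(128, dim=-1).indices.tolist()         # exact teacher top-K

recall_raw     = hnsw_recall(q, k, reference_topk)                    # vanilla ANN on raw Q/K
recall_learned = hnsw_recall(hidden @ W_Qs, hidden @ W_Ks, reference_topk)  # learned Q_s/K_s
```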
## What's in this repo
Per-layer linear search projections `(W_Qs, W_Ks)` of shape `[2560, 64]`,
trained against the frozen base model's attention via contrastive +
distillation losses. At inference these produce 64-d "search vectors" that
let an off-the-shelf FAISS HNSW index pick the top-K keys to attend to,
replacing dense `O(L²)` attention with `O(L·K)` ANN-substituted attention.
Layers covered (pilot): `[4, 8, 12, 16, 20, 24]` (6 of 36 layers, ~2M trainable params).
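Schematically, the substituted attention for one head of one sequence looks like this
(a simplified sketch, not the repo's actual runtime; batching, GQA, RoPE handling, and
the decode path are omitted):
```python
import math
import faiss
import torch

def ann_attention(hidden, q, k, v, W_Qs, W_Ks, K=128):
    """hidden: [L, 2560] layer input; q, k, v: [L, d_head] one head's projections;
    W_Qs, W_Ks: [2560, 64] trained search projections."""
    L = hidden.shape[0]
    # 64-d search vectors -- queries and keys share a distribution by construction
    q_s = (hidden @ W_Qs).float().cpu().numpy()
    k_s = (hidden @ W_Ks).float().cpu().numpy()

    index = faiss.IndexHNSWFlat(64, 32, faiss.METRIC_INNER_PRODUCT)   # off-the-shelf HNSW
    index.add(k_s)
    _, topk = index.search(q_s, K)                  # [L, K] candidate key positions

    out = torch.zeros_like(v)
    for i in range(L):
        cand = torch.tensor([j for j in topk[i] if 0 <= j <= i])      # keep causal candidates
        if cand.numel() == 0:
            cand = torch.tensor([i])
        scores = (q[i] @ k[cand].T) / math.sqrt(q.shape[-1])          # O(K) scores, not O(L)
        out[i] = torch.softmax(scores, dim=-1) @ v[cand]              # single softmax over top-K
    return out
```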
## Pilot results (final, 2K steps on WikiText-103)
| Step | Recall@K=128 | PPL gap (full vs ANN) |
|---|---|---|
| 500 | 47.4% | 1.21% |
| 1000 | 50.7% | 0.68% |
| 1500 | 50.9% | 0.68% |
| **2000 (final)** | **50.9%** | **0.71%** |
PPL gap is the primary signal: at <1% relative gap, the model's output
quality is preserved under ANN substitution. Recall plateaus around step 1000
because the softmax-relevant keys concentrate in the top ~30; disagreement
on positions 30-128 falls on the near-zero-weight tail and doesn't affect output.
### K-retrieve Pareto (pilot step 2000, FAISS HNSW)
`PPL_full = 9.958`
| K | Recall@K | PPL_ANN | PPL gap |
|---|---|---|---|
| 16 | 24.9% | 10.71 | +7.51% |
| 32 | 22.8% | 10.41 | +4.51% |
| 64 | 23.1% | 10.20 | +2.42% |
| 128 | 26.0% | 10.04 | +0.82% |
| 256 | 31.6% | 9.88 | **−0.79%** |
| 512 | 40.8% | 9.67 | **−2.89%** |
On this small WikiText slice, K ≥ 256 produced lower measured PPL than
the full-attention reference. Sparse-softmax denoising is a plausible
explanation, but with only 12 eval batches, sample noise, packed-boundary
artifacts (the pilot trained with packing on; the repo default is now off),
and partial-layer substitution acting like regularization are also candidates.
We treat it as a hypothesis to confirm via an exact-topK oracle (full QK^T
→ top-K → restricted attention) at the same K; that separates "denoising
from any sparsity" from "denoising from learned projections."
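A minimal sketch of that oracle for one head, assuming access to the raw per-head
scores (illustrative, not the planned eval harness):
```python
import torch

def exact_topk_scores(q, k, K=256):
    """Full QK^T -> causal mask -> keep only the exact top-K scores per query.
    Softmax over the returned rows gives the exact-topK restricted attention."""
    L, d = q.shape
    scores = (q @ k.T) / d**0.5                                        # full O(L^2) scoring
    causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    topk = scores.topk(min(K, L), dim=-1).indices                      # exact top-K positions
    restricted = torch.full_like(scores, float("-inf"))
    restricted.scatter_(1, topk, scores.gather(1, topk))               # drop everything else
    return restricted
```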
Code-level sanity checks pass: same input sequences for `ppl_full` vs
`ppl_ann`, intact causal mask in retrieval, single softmax over retrieved
K with no wrapper leakage between iterations.
### Compute / quality knobs (FLOP-counted)
`L = 4096`. Compute reduction refers to the attention scoring step, ≈ `L / K`.
These are FLOP estimates, not measured wall-clock: the FAISS path in this
repo is a research prototype that does CPU index builds and GPU↔CPU
transfers, so it is not the right thing to time.
| K | PPL gap | Attention scoring reduction |
|---|---|---|
| 512 | −2.89% | ~8× |
| 256 | −0.79% | ~16× |
| 128 | +0.82% | ~32× |
| 64 | +2.42% | ~64× |
| 32 | +4.51% | ~128× |
| 16 | +7.51% | ~256× |
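The reduction column is just the ratio of keys scored per query, `L / K`:
```python
L = 4096
for K in (512, 256, 128, 64, 32, 16):
    # dense attention scores L keys per query; ANN-substituted attention scores K
    print(f"K={K}: ~{L // K}x fewer attention scores per query")
```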
Eval scope: 12 sequences × 4K tokens of WikiText-103 validation (~50K
tokens). Read these as "what we observed on this slice", not
population-level estimates.
The K-sweep recall numbers (24-41%) and the in-training `evaluate()` recall
(50.9% at K=128) come from different sampled subsets of the streaming split
and shouldn't be directly compared. The repo also reports `mass@K` (sum of
teacher attention probability captured by the search top-K); that's the
more direct retrieval-quality metric when softmax is sharp.
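Given teacher attention probabilities `p` (`[L, L]`, rows summing to 1) and retrieved
indices `idx` (`[L, K]`, no `-1` padding), the two metrics reduce to roughly this
(names are illustrative):
```python
import torch

def recall_at_k(p, idx):
    """Overlap between retrieved keys and the teacher's exact top-K keys."""
    K = idx.shape[-1]
    true_topk = p.topk(K, dim=-1).indices                               # [L, K]
    hits = (idx.unsqueeze(-1) == true_topk.unsqueeze(1)).any(-1)        # [L, K]
    return hits.float().mean().item()

def mass_at_k(p, idx):
    """Teacher attention probability mass captured by the retrieved keys."""
    return p.gather(-1, idx).sum(-1).mean().item()
```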
### Per-layer recall (pilot)
| Layer | Recall@K=128 | Recall@K=512 |
|---|---|---|
| 4 | 15.8% | 34.7% |
| 8 | 22.2% | 38.7% |
| 12 | 23.4% | 39.1% |
| 16 | 31.9% | 45.2% |
| 20 | 31.4% | 42.6% |
| 24 | 31.1% | 44.4% |
Early layers are harder for content-addressable retrieval: their attention
is more local/positional than semantic. The pattern is consistent across K,
so it's a property of the layer rather than noise.
### Caveats / what's next
- **Packing**: pilot training and eval ran with sequence packing on (no
segment-level causal mask, since transformers' default forward doesn't
build one; a sketch of such a mask follows this list). The relative PPL
gap between full and ANN is internally consistent under this confound, but
the negative gap at K ≥ 256 has at least three candidate explanations we
haven't disentangled: (a) sparse-softmax denoising, (b) ANN happening to
filter cross-document keys that full attention attends to, (c) sample noise
on a small eval. The default config now has packing off so the next run
isolates (a).
- **Exact-topK oracle**: a four-way Pareto (full vs. exact top-K vs.
search-topK exact vs. search-ANN) is the natural follow-up to separate
"denoising from any sparsity" from "denoising from learned projections."
- **Wall-clock**: not measured. The FAISS path in the repo is a CPU-side
research prototype, not a deployable runtime. A GPU-resident top-K kernel
is the next engineering step.
- **34-layer headline**: queued (`make_headline_config()` is wired); its
checkpoints will be mirrored here when the run completes.
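For reference, the missing segment-level causal mask for packed sequences can be
built from per-token document ids roughly like this (a sketch; `segment_ids` is an
assumed input, not a repo variable):
```python
import torch

def packed_causal_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    """Boolean [L, L] mask: position i may attend to position j only if
    j <= i AND both tokens belong to the same packed document."""
    L = segment_ids.shape[0]
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool))
    same_doc = segment_ids.unsqueeze(0) == segment_ids.unsqueeze(1)
    return causal & same_doc

# e.g. two packed documents of lengths 3 and 2
mask = packed_causal_mask(torch.tensor([0, 0, 0, 1, 1]))
```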
## Files
| File | What |
|---|---|
| `search_step_1000.pt` | Mid-training checkpoint (step 1000, 0.68% PPL gap) |
| `search_step_2000.pt` | Final pilot checkpoint (step 2000, 0.71% PPL gap) |
Each contains `{step, search_module: state_dict, optimizer, scheduler, config}`.
## Loading
```python
import torch
from transformers import AutoModelForCausalLM
# Search module class is in the GitHub repo (model.py)
from model import SearchProjectionModule
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

search = SearchProjectionModule(
    d_model=2560, d_search=64,
    layer_indices=[4, 8, 12, 16, 20, 24],
    use_mlp=False,
).to(base.device).to(torch.bfloat16)

ckpt = torch.load("search_step_2000.pt", map_location="cpu", weights_only=False)
search.load_state_dict(ckpt["search_module"])
```
Use `inference.install_ann_attention(...)` (in the GitHub repo) to monkey-patch
the trained layers and run with FAISS HNSW retrieval at inference time.
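Continuing the loading snippet, the call looks roughly like this (the keyword
arguments are illustrative guesses; check the `inference` module in the GitHub repo
for the actual signature):
```python
from inference import install_ann_attention  # lives in the GitHub repo

# Argument names below are assumptions, not the confirmed API surface.
install_ann_attention(base, search, layer_indices=[4, 8, 12, 16, 20, 24], k=128)
```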
## Training recipe
- Frozen base: Qwen3-4B-Instruct-2507 (36 layers, hidden 2560, GQA 32:8).
- Data: WikiText-103 raw, 4K-token sequences (packing was on at training
time; default in the repo is now off; see Caveats).
- 2000 steps, batch 8, lr 1e-4 (cosine, 100-step warmup), AdamW.
- `α=β=1` (contrastive + KL distillation, both averaged over layers); see the loss sketch after this list.
- bf16 weights, fp32 loss math.
- SDPA attention (B200, no flash-attn package needed).
- Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
- Total wall-clock: ~25 min on a single B200.
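A schematic of the per-layer objective, assuming an InfoNCE-style contrastive term
over search-space scores and a KL term against the teacher's attention rows (a sketch
of the loss shape, not a transcription of the repo's training code; causal masking
omitted):
```python
import torch
import torch.nn.functional as F

def search_loss(q_s, k_s, teacher_probs, alpha=1.0, beta=1.0, tau=1.0):
    """q_s, k_s: [L, 64] search vectors; teacher_probs: [L, L] frozen attention rows."""
    logits = (q_s @ k_s.T) / tau                          # search-space scores
    positives = teacher_probs.argmax(dim=-1)              # teacher's strongest key per query
    contrastive = F.cross_entropy(logits, positives)      # pull q_s toward its top teacher key
    distill = F.kl_div(F.log_softmax(logits, dim=-1),
                       teacher_probs, reduction="batchmean")
    return alpha * contrastive + beta * distill
```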
## License
The search projections are released under Apache-2.0 (matching the base model).