# ann-sparseattention
Search projections for ANN-substituted attention on
Qwen/Qwen3-4B-Instruct-2507.
Code: github.com/unixsysdev/ann-sparseattention
## Current status
Research prototype. The trained projections work, the runtime is a correctness prototype, and the evaluation envelope is narrow. Treat all reported numbers as preliminary.
- Validated: 6-layer pilot on Qwen3-4B-Instruct-2507; WikiText-103 PPL preserved at K=128 (gap ≈ +0.7%); learned projections retrieve attention-relevant keys.
- Not yet validated: 34-layer / whole-model substitution; long-context tasks (LongBench, RULER, needle-in-a-haystack); wall-clock speedup vs FlashAttention/SDPA; KV-cache decode-mode integration; GPU-resident ANN kernel.
- Runtime caveat: the FAISS path here builds CPU indexes per batch and the gather step uses dense-style tensor expansion. The compute-reduction numbers below are algorithmic scoring reductions, not measured wall-clock speedups.
## Relation to RetrievalAttention
RetrievalAttention (Liu et al., 2024) shows that vanilla ANN over the model's native Q, K vectors fails because Q and K live in mismatched distributions: they were never trained to be each other's nearest neighbors, only to score via dot product. Their fix is at index time: an attention-aware graph construction (RoarGraph-style).
This work attacks the same problem from the opposite direction. We
train a tiny shared projection (W_Qs, W_Ks ∈ R^{2560×64}) so that
q_search and k_search live in the same distribution by construction.
Off-the-shelf FAISS HNSW with default parameters then suffices.
| Approach | Search space | Index | Trainable |
|---|---|---|---|
| Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no (fails: Q/K OOD) |
| RetrievalAttention | original Q/K | attention-aware graph | no |
| This work | learned Q_s / K_s | off-the-shelf | yes (~2-11M params) |
Contribution: eliminate the Q/K mismatch at index-build time via distillation, instead of patching it at search time. The clean validating experiment (vanilla FAISS over raw Q/K vs. learned Q_s/K_s vs. exact teacher top-K) is the next planned run.
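To make the search path concrete, here is a minimal sketch of indexing learned search keys with off-the-shelf FAISS HNSW; the shapes, the inner-product metric, and the HNSW parameter M=32 are illustrative assumptions rather than the repo's exact configuration:

```python
import faiss
import numpy as np

# Illustrative stand-ins for one layer's learned 64-d search vectors
# (in the real pipeline these come from W_Qs / W_Ks applied per token).
k_search = np.random.randn(4096, 64).astype("float32")   # [L, 64] projected keys
q_search = np.random.randn(4096, 64).astype("float32")   # [L, 64] projected queries

# Off-the-shelf HNSW index; M=32 and the inner-product metric are assumptions here.
index = faiss.IndexHNSWFlat(64, 32, faiss.METRIC_INNER_PRODUCT)
index.add(k_search)

# Retrieve K=128 candidate keys per query position; causal filtering and the
# restricted softmax happen downstream in the attention wrapper.
_, topk_idx = index.search(q_search, 128)                 # topk_idx: [L, 128]
```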
## What's in this repo
Per-layer linear search projections (W_Qs, W_Ks) of shape [2560, 64],
trained against the frozen base model's attention via contrastive +
distillation losses. At inference these produce 64-d "search vectors" that
let an off-the-shelf FAISS HNSW index pick the top-K keys to attend to,
replacing dense O(L²) attention with O(L·K) ANN-substituted attention.
Layers covered (pilot): [4, 8, 12, 16, 20, 24] (6 of 36 layers, ~2M trainable params).
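For orientation, a minimal stand-in for what these projections look like; this is a sketch only (the real `SearchProjectionModule` lives in `model.py`), and feeding it the layer's hidden states is an assumption:

```python
import torch
import torch.nn as nn

class SearchProjectionsSketch(nn.Module):
    """One bias-free [2560 -> 64] W_Qs / W_Ks pair per trained layer
    (2 * 2560 * 64 * 6 layers ~ 2M params, matching the pilot)."""

    def __init__(self, d_model=2560, d_search=64,
                 layer_indices=(4, 8, 12, 16, 20, 24)):
        super().__init__()
        self.q_proj = nn.ModuleDict(
            {str(i): nn.Linear(d_model, d_search, bias=False) for i in layer_indices})
        self.k_proj = nn.ModuleDict(
            {str(i): nn.Linear(d_model, d_search, bias=False) for i in layer_indices})

    def forward(self, hidden_states, layer_idx):
        # hidden_states: [batch, seq, d_model] at the given layer (assumed input).
        q_s = self.q_proj[str(layer_idx)](hidden_states)   # [batch, seq, 64]
        k_s = self.k_proj[str(layer_idx)](hidden_states)
        return q_s, k_s
```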
## Pilot results (final, 2K steps on WikiText-103)
| Step | Recall@K=128 | PPL gap (full vs ANN) |
|---|---|---|
| 500 | 47.4% | 1.21% |
| 1000 | 50.7% | 0.68% |
| 1500 | 50.9% | 0.68% |
| 2000 (final) | 50.9% | 0.71% |
PPL gap is the primary signal: at <1% relative gap, the model's output quality is preserved under ANN substitution. Recall plateaus around step 1000 because the softmax-relevant keys concentrate in the top ~30; disagreement on positions 30–128 falls on the near-zero-weight tail and doesn't affect the output.
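A quick way to check that concentration claim on your own traces (hedged sketch; `teacher_probs` stands for whatever attention map you extract from the frozen base model):

```python
import torch

def topm_attention_mass(teacher_probs: torch.Tensor, m: int = 30) -> float:
    """teacher_probs: [L, L] attention probabilities for one layer/head.
    Returns the average fraction of probability mass carried by each query's
    top-m keys; values near 1.0 mean keys beyond rank m are near-zero-weight tail."""
    top_m = teacher_probs.topk(m, dim=-1).values   # [L, m]
    return top_m.sum(dim=-1).mean().item()
```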
## K-retrieve Pareto (pilot step 2000, FAISS HNSW)
PPL_full = 9.958
| K | Recall@K | PPL_ANN | PPL gap |
|---|---|---|---|
| 16 | 24.9% | 10.71 | +7.51% |
| 32 | 22.8% | 10.41 | +4.51% |
| 64 | 23.1% | 10.20 | +2.42% |
| 128 | 26.0% | 10.04 | +0.82% |
| 256 | 31.6% | 9.88 | -0.79% |
| 512 | 40.8% | 9.67 | -2.89% |
On this small WikiText slice, K ≥ 256 produced lower measured PPL than the full-attention reference. A plausible explanation is sparse-softmax denoising, but with only 12 eval batches, sample noise, packed-boundary artifacts (the pilot trained with packing on; the repo default is now off), and partial-layer substitution acting like regularization are also candidates. We treat it as a hypothesis to confirm via an exact-topK oracle (full QK^T → top-K → restricted attention) at the same K, which separates "denoising from any sparsity" from "denoising from learned projections".
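A minimal sketch of that oracle for a single head, under the assumption that `q`, `k`, `v` are `[L, d]` tensors and `causal_mask` is a boolean `[L, L]` lower-triangular mask (GQA/head bookkeeping and batching omitted):

```python
import torch
import torch.nn.functional as F

def exact_topk_attention(q, k, v, top_k, causal_mask):
    """Oracle: full QK^T, exact top-K per query, softmax restricted to those keys."""
    scores = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5       # [L, L]
    scores = scores.masked_fill(~causal_mask, float("-inf"))

    # Keep only the exact top-K scores per query; everything else gets -inf.
    topk = scores.topk(top_k, dim=-1)
    restricted = torch.full_like(scores, float("-inf"))
    restricted.scatter_(-1, topk.indices, topk.values)

    return F.softmax(restricted, dim=-1) @ v                      # [L, d]
```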
Code-level sanity checks pass: same input sequences for ppl_full vs
ppl_ann, intact causal mask in retrieval, single softmax over retrieved
K with no wrapper leakage between iterations.
## Compute / quality knobs (FLOP-counted)
L = 4096. The compute reduction applies to the attention scoring step and is ≈ L / K. These are FLOP estimates, not measured wall-clock numbers: the FAISS path in this repo is a research prototype that does CPU index builds and GPU↔CPU transfers, so it is not the right thing to time.
| K | PPL gap | Attention scoring reduction |
|---|---|---|
| 512 | -2.89% | ~8× |
| 256 | -0.79% | ~16× |
| 128 | +0.82% | ~32× |
| 64 | +2.42% | ~64× |
| 32 | +4.51% | ~128× |
| 16 | +7.51% | ~256× |
Eval scope: 12 sequences × 4K tokens of WikiText-103 validation (~50K tokens). Read these as "what we observed on this slice", not population-level estimates.
The K-sweep recall numbers (24–41%) and the in-training `evaluate()` recall (50.9% at K=128) come from different sampled subsets of the streaming split and shouldn't be directly compared. The repo also reports mass@K (the sum of teacher attention probability captured by the search top-K); that is the more direct retrieval-quality metric when the softmax is sharp.
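For concreteness, a hedged sketch of computing both metrics from a teacher attention map and the search results; the tensor names and shapes are assumptions, not the repo's eval code:

```python
import torch

def recall_and_mass_at_k(teacher_probs, retrieved_idx):
    """teacher_probs: [L, L] attention probabilities from the frozen model.
    retrieved_idx:  [L, K] key indices returned by the ANN search per query."""
    K = retrieved_idx.shape[-1]

    # Recall@K: fraction of the exact teacher top-K that the search also returned.
    exact_topk = teacher_probs.topk(K, dim=-1).indices                  # [L, K]
    match = exact_topk.unsqueeze(-1) == retrieved_idx.unsqueeze(-2)     # [L, K, K]
    recall = match.any(dim=-1).float().mean().item()

    # mass@K: teacher attention probability captured by the retrieved keys.
    mass = teacher_probs.gather(-1, retrieved_idx).sum(dim=-1).mean().item()
    return recall, mass
```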
## Per-layer recall (pilot)
| Layer | Recall@K=128 | Recall@K=512 |
|---|---|---|
| 4 | 15.8% | 34.7% |
| 8 | 22.2% | 38.7% |
| 12 | 23.4% | 39.1% |
| 16 | 31.9% | 45.2% |
| 20 | 31.4% | 42.6% |
| 24 | 31.1% | 44.4% |
Early layers are harder for content-addressable retrieval: their attention is more local/positional than semantic. This pattern is consistent across K, so it's a property of the layer rather than noise.
## Caveats / what's next
- Packing: pilot training and eval ran with sequence packing on (no segment-level causal mask, since transformers' default forward doesn't build one). The relative PPL gap between full and ANN is internally consistent under this confound, but the negative gap at K ≥ 256 has at least three candidate explanations we haven't disentangled: (a) sparse-softmax denoising, (b) ANN happening to filter cross-document keys that full attention attends to, (c) sample noise on a small eval. The default config now has packing off so the next run isolates (a).
- Exact-topK oracle: a four-way Pareto (full vs. exact top-K vs. search-topK exact vs. search-ANN) is the natural follow-up to separate "denoising from any sparsity" from "denoising from learned projections."
- Wall-clock: not measured. The FAISS path in the repo is a CPU-side research prototype, not a deployable runtime. A GPU-resident top-K kernel is the next engineering step.
- 34-layer headline: queued (`make_headline_config()` is wired); its checkpoints will be mirrored here when it runs.
## Files
| File | What |
|---|---|
| `search_step_1000.pt` | Mid-training checkpoint (step 1000, 0.68% PPL gap) |
| `search_step_2000.pt` | Final pilot checkpoint (step 2000, 0.71% PPL gap) |
Each contains {step, search_module: state_dict, optimizer, scheduler, config}.
## Loading
```python
import torch
from transformers import AutoModelForCausalLM

# Search module class is in the GitHub repo (model.py)
from model import SearchProjectionModule

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

search = SearchProjectionModule(
    d_model=2560, d_search=64,
    layer_indices=[4, 8, 12, 16, 20, 24],
    use_mlp=False,
).to(base.device).to(torch.bfloat16)

ckpt = torch.load("search_step_2000.pt", map_location="cpu", weights_only=False)
search.load_state_dict(ckpt["search_module"])
```
Use `inference.install_ann_attention(...)` (in the GitHub repo) to monkey-patch the trained layers and run with FAISS HNSW retrieval at inference time.
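An illustrative call shape only; the actual signature of `install_ann_attention` (including how the retrieval budget is passed) lives in `inference.py` in the repo, and `k=128` here is an assumption:

```python
# Illustrative only: check inference.py in the GitHub repo for the real signature.
from inference import install_ann_attention

install_ann_attention(base, search, k=128)   # patch the 6 trained layers (assumed kwarg)

# After patching, the model is used like any Hugging Face causal LM.
```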
## Training recipe
- Frozen base: Qwen3-4B-Instruct-2507 (36 layers, hidden 2560, GQA 32:8).
- Data: WikiText-103 raw, 4K-token sequences (packing was on at training time; the repo default is now off, see Caveats).
- 2000 steps, batch 8, lr 1e-4 (cosine, 100-step warmup), AdamW.
- α = β = 1 (contrastive + KL distillation, averaged over the trained layers); a hedged sketch of the two losses follows after this list.
- bf16 weights, fp32 loss math.
- SDPA attention (B200, no flash-attn package needed).
- Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
- Total wall-clock: ~25 min on a single B200.
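As noted in the loss bullet above, here is a hedged sketch of the two training signals for a single layer (causal masking, GQA head handling, and the exact temperature / positive-set definition are assumptions, not the repo's training code):

```python
import torch
import torch.nn.functional as F

def search_losses(q_s, k_s, teacher_probs, tau=0.1, top_m=32):
    """q_s, k_s:      [L, 64] search vectors for one trained layer.
    teacher_probs: [L, L] attention probabilities from the frozen base model."""
    log_p_student = F.log_softmax((q_s @ k_s.T) / tau, dim=-1)     # [L, L]

    # Distillation: match the teacher's attention distribution (KL term).
    kl = F.kl_div(log_p_student, teacher_probs, reduction="batchmean")

    # Contrastive term: push mass toward the teacher's top-m keys (InfoNCE-like).
    pos_idx = teacher_probs.topk(top_m, dim=-1).indices            # [L, top_m]
    contrastive = -log_p_student.gather(-1, pos_idx).mean()

    # Combined as alpha * contrastive + beta * kl with alpha = beta = 1.
    return contrastive, kl
```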
## License
The search projections are released under Apache-2.0 (matching the base model).