---
language:
- en
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
- sparse-attention
- ann-attention
- distillation
- search-projection
- inference-optimization
library_name: pytorch
---
# ann-sparseattention
Search projections for ANN-substituted attention on
[`Qwen/Qwen3-4B-Instruct-2507`](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507).
Code: [github.com/unixsysdev/ann-sparseattention](https://github.com/unixsysdev/ann-sparseattention)
## Current status
Research prototype. The trained projections work, the runtime is a
correctness prototype, and the evaluation envelope is narrow. Treat reported
numbers as preliminary.
**Validated:** 6-layer pilot on Qwen3-4B-Instruct-2507; WikiText-103 PPL
preserved at K=128 (gap ≈ +0.7%); learned projections retrieve
attention-relevant keys.
**Not yet validated:** 34-layer / whole-model substitution; long-context
tasks (LongBench, RULER, needle); wall-clock speedup vs FlashAttention/SDPA;
KV-cache decode-mode integration; GPU-resident ANN kernel.
**Runtime caveat:** the FAISS path here builds CPU indexes per batch and
the gather step uses dense-style tensor expansion. Compute-reduction
numbers below are *algorithmic scoring reductions, not measured wall-clock
speedups.*
## Relation to RetrievalAttention
RetrievalAttention (Liu et al., 2024) shows that **vanilla ANN over the
model's native Q, K vectors fails** because Q and K live in mismatched
distributions — they were never trained to be each other's nearest
neighbors, only to score via dot product. Their fix is at *index time*:
an attention-aware graph construction (RoarGraph-style).
This work attacks the same problem from the opposite direction. We
**train a tiny shared projection** (`W_Qs, W_Ks → R^64`) so that
`q_search` and `k_search` live in the same distribution by construction.
Off-the-shelf FAISS HNSW with default parameters then suffices.
| | Search space | Index | Trainable |
|---|---|---|---|
| Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no — fails (Q/K OOD) |
| RetrievalAttention | original Q/K | attention-aware graph | no |
| **This work** | **learned Q\_s / K\_s** | **off-the-shelf** | **yes (~2-11M params)** |
Contribution: *eliminate Q/K mismatch at index-build time via distillation,
instead of patching it at search time.* The clean validation experiment —
vanilla FAISS over raw Q/K vs. learned Q\_s/K\_s vs. exact teacher top-K
— is the next planned run.
## What's in this repo
Per-layer linear search projections `(W_Qs, W_Ks)` of shape `[2560, 64]`,
trained against the frozen base model's attention via contrastive +
distillation losses. At inference these produce 64-d "search vectors" that
let an off-the-shelf FAISS HNSW index pick the top-K keys to attend to,
replacing dense `O(L²)` attention with `O(L·K)` ANN-substituted attention.
Layers covered (pilot): `[4, 8, 12, 16, 20, 24]` — 6 of 36 layers, ~2M trainable params.
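For orientation, a minimal sketch of what a module with these shapes could look like: one `(W_Qs, W_Ks)` pair per covered layer, projecting 2560-d hidden states to 64-d search vectors. The real `SearchProjectionModule` lives in the GitHub repo's `model.py` and may differ in details (biases, normalization, the exact input it projects); this is only a shape-level illustration.

```python
import torch.nn as nn

class SearchProjectionSketch(nn.Module):
    """Hypothetical stand-in for the repo's SearchProjectionModule."""
    def __init__(self, d_model=2560, d_search=64,
                 layer_indices=(4, 8, 12, 16, 20, 24)):
        super().__init__()
        # One (W_Qs, W_Ks) pair per covered layer: [2560, 64] each,
        # so 6 layers x 2 x 2560 x 64 ~ 2M trainable params.
        self.q_proj = nn.ModuleDict({str(i): nn.Linear(d_model, d_search, bias=False)
                                     for i in layer_indices})
        self.k_proj = nn.ModuleDict({str(i): nn.Linear(d_model, d_search, bias=False)
                                     for i in layer_indices})

    def forward(self, layer_idx, hidden_states):
        # hidden_states: [batch, seq_len, d_model] from the frozen base model
        q_search = self.q_proj[str(layer_idx)](hidden_states)  # [B, L, 64]
        k_search = self.k_proj[str(layer_idx)](hidden_states)  # [B, L, 64]
        return q_search, k_search
```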
## Pilot results (final, 2K steps on WikiText-103)
| Step | Recall@K=128 | PPL gap (full vs ANN) |
|---|---|---|
| 500 | 47.4% | 1.21% |
| 1000 | 50.7% | 0.68% |
| 1500 | 50.9% | 0.68% |
| **2000 (final)** | **50.9%** | **0.71%** |
PPL gap is the primary signal — at <1% relative gap, the model's output
quality is preserved under ANN substitution. Recall plateaus around step 1000
because the softmax-relevant keys concentrate in the top ~30; disagreement
on positions 30-128 falls on the near-zero-weight tail and doesn't affect output.
### K-retrieve Pareto (pilot step 2000, FAISS HNSW)
`PPL_full = 9.958`
| K | Recall@K | PPL_ANN | PPL gap |
|---|---|---|---|
| 16 | 24.9% | 10.71 | +7.51% |
| 32 | 22.8% | 10.41 | +4.51% |
| 64 | 23.1% | 10.20 | +2.42% |
| 128 | 26.0% | 10.04 | +0.82% |
| 256 | 31.6% | 9.88 | **−0.79%** |
| 512 | 40.8% | 9.67 | **−2.89%** |
On this small WikiText slice, K ≥ 256 produced lower measured PPL than
the full-attention reference. A plausible explanation is sparse-softmax
denoising, but with only 12 eval batches, sample noise, packed-boundary
artifacts (the pilot trained with packing on; the repo default is now off),
and partial-layer substitution acting as regularization are also candidates.
We treat it as a hypothesis to confirm with an exact-topK oracle (full QK^T
→ top-K → restricted attention; sketched below) at the same K, which
separates "denoising from any sparsity" from "denoising from learned
projections."
Code-level sanity checks pass: same input sequences for `ppl_full` vs
`ppl_ann`, intact causal mask in retrieval, single softmax over retrieved
K with no wrapper leakage between iterations.
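A minimal sketch of that exact-topK oracle, written for a single head with plain tensors for illustration only (the repo's actual eval path handles GQA, RoPE, and batching):

```python
import torch

def exact_topk_attention(q, k, v, K, causal=True):
    """Oracle: score every key densely, keep only the true top-K per query,
    then softmax over the restricted set. Illustrative, single-head shapes."""
    # q, k, v: [seq_len, head_dim]
    scores = (q @ k.T) / q.shape[-1] ** 0.5                    # [L, L]
    if causal:
        L = q.shape[0]
        future = torch.triu(torch.ones(L, L, dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(future, float("-inf"))
    topk = scores.topk(min(K, scores.shape[-1]), dim=-1)       # exact teacher top-K
    restricted = torch.full_like(scores, float("-inf"))
    restricted.scatter_(-1, topk.indices, topk.values)         # keep only top-K scores
    return torch.softmax(restricted, dim=-1) @ v               # [L, head_dim]
```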
### Compute / quality knobs (FLOP-counted)
`L = 4096`. Compute reduction is the attention scoring step, ≈ `L / K`.
These are FLOP estimates, not measured wall-clock — the FAISS path in this
repo is a research prototype that does CPU index builds and GPU↔CPU
transfers, so it is not the right thing to time.
| K | PPL gap | Attention scoring reduction |
|---|---|---|
| 512 | −2.89% | ~8× |
| 256 | −0.79% | ~16× |
| 128 | +0.82% | ~32× |
| 64 | +2.42% | ~64× |
| 32 | +4.51% | ~128× |
| 16 | +7.51% | ~256× |
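The reduction column is just the ratio of keys scored per query. A back-of-the-envelope check that counts only the `QK^T` scoring step and ignores the projection and ANN-search overhead (per-head dim 128 is assumed here for illustration):

```python
L, d_head = 4096, 128  # sequence length and per-head dim (illustrative)

def scoring_flops(keys_scored):
    # ~2 FLOPs per multiply-add, per query, per scored key
    return 2 * L * keys_scored * d_head

dense = scoring_flops(L)           # full attention scores every key
for K in (512, 256, 128, 64, 32, 16):
    sparse = scoring_flops(K)      # ANN-substituted: only K retrieved keys
    print(f"K={K:4d}  reduction ~{dense / sparse:.0f}x")  # = L / K
```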
Eval scope: 12 sequences × 4K tokens of WikiText-103 validation (~50K
tokens). Read these as "what we observed on this slice", not
population-level estimates.
The K-sweep recall numbers (24–41%) and the in-training `evaluate()` recall
(50.9% at K=128) come from different sampled subsets of the streaming split
and shouldn't be directly compared. The repo also reports `mass@K` (sum of
teacher attention probability captured by the search top-K) — that's the
more direct retrieval-quality metric when softmax is sharp.
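For reference, a sketch of how the two metrics relate for a single query, assuming you have the teacher's attention weights and the indices returned by search (illustrative names, not the repo's API):

```python
import torch

def recall_and_mass_at_k(teacher_probs, retrieved_idx, K):
    """teacher_probs: [L] teacher softmax weights of one query over its keys.
    retrieved_idx: [K] key indices returned by the search top-K."""
    teacher_topk = teacher_probs.topk(K).indices
    # recall@K: overlap between the search top-K and the teacher's top-K
    recall = torch.isin(retrieved_idx, teacher_topk).float().mean()
    # mass@K: teacher attention probability captured by the retrieved keys
    mass = teacher_probs[retrieved_idx].sum()
    return recall.item(), mass.item()
```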
### Per-layer recall (pilot)
| Layer | Recall@K=128 | Recall@K=512 |
|---|---|---|
| 4 | 15.8% | 34.7% |
| 8 | 22.2% | 38.7% |
| 12 | 23.4% | 39.1% |
| 16 | 31.9% | 45.2% |
| 20 | 31.4% | 42.6% |
| 24 | 31.1% | 44.4% |
Early layers are harder for content-addressable retrieval — their attention
is more local/positional than semantic. The pattern is consistent across K,
so it's a property of the layer rather than noise.
### Caveats / what's next
- **Packing**: pilot training and eval ran with sequence packing on (no
segment-level causal mask, since transformers' default forward doesn't
build them). The relative PPL gap between full and ANN is internally
consistent under this confound, but the negative gap at K≥256 has at
least three candidate explanations we haven't disentangled —
(a) sparse-softmax denoising, (b) ANN happening to filter cross-document
keys that full attention attends to, (c) sample noise on a small eval.
The default config now has packing off so the next run isolates (a).
- **Exact-topK oracle**: a four-way Pareto (full vs. exact top-K vs.
search-topK exact vs. search-ANN) is the natural follow-up to separate
"denoising from any sparsity" from "denoising from learned projections."
- **Wall-clock**: not measured. The FAISS path in the repo is a CPU-side
research prototype, not a deployable runtime. A GPU-resident top-K kernel
is the next engineering step.
- **34-layer headline**: queued (`make_headline_config()` is wired); its
checkpoints will be mirrored here when that run completes.
## Files
| File | What |
|---|---|
| `search_step_1000.pt` | Mid-training checkpoint (step 1000, 0.68% PPL gap) |
| `search_step_2000.pt` | Final pilot checkpoint (step 2000, 0.71% PPL gap) |
Each contains `{step, search_module: state_dict, optimizer, scheduler, config}`.
## Loading
```python
import torch
from transformers import AutoModelForCausalLM
# Search module class is in the GitHub repo (model.py)
from model import SearchProjectionModule
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)
search = SearchProjectionModule(
    d_model=2560, d_search=64,
    layer_indices=[4, 8, 12, 16, 20, 24],
    use_mlp=False,
).to(base.device).to(torch.bfloat16)
ckpt = torch.load("search_step_2000.pt", map_location="cpu", weights_only=False)
search.load_state_dict(ckpt["search_module"])
```
Use `inference.install_ann_attention(...)` (in the GitHub repo) to monkey-patch
the trained layers and run with FAISS HNSW retrieval at inference time.
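The retrieval step itself is plain FAISS. A minimal sketch of the prefill-style pattern described above, using default HNSW parameters rather than the repo's exact code; causal filtering, GQA head mapping, and the metric choice are assumptions handled by the real runtime:

```python
import faiss
import torch

def ann_topk_indices(q_search, k_search, K, M=32):
    """q_search, k_search: [L, 64] search vectors for one sequence and layer.
    Returns [L, K] key indices to restrict attention to."""
    d = k_search.shape[-1]
    # Off-the-shelf HNSW; inner-product metric is an assumption here.
    index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
    index.add(k_search.float().cpu().numpy())                 # CPU index per sequence
    _, idx = index.search(q_search.float().cpu().numpy(), K)  # ANN top-K per query
    return torch.from_numpy(idx)
```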
## Training recipe
- Frozen base: Qwen3-4B-Instruct-2507 (36 layers, hidden 2560, GQA 32:8).
- Data: WikiText-103 raw, 4K-token sequences (packing was on at training
time; default in the repo is now off — see Caveats).
- 2000 steps, batch 8, lr 1e-4 (cosine, 100-step warmup), AdamW.
- `α=β=1` (contrastive + KL distillation, both losses averaged over the trained layers; a hedged loss sketch follows this list).
- bf16 weights, fp32 loss math.
- SDPA attention (B200, no flash-attn package needed).
- Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
- Total wall-clock: ~25 min on a single B200.
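A heavily hedged sketch of what that `α`/`β` objective could look like: an InfoNCE-style contrastive term that pulls each `q_search` toward the teacher's most-attended key, plus a KL term that distills the teacher's attention distribution into the search-space scores. The actual formulation is defined in the GitHub repo's training code; the temperature, positive selection, and reduction below are interpretations, not the repo's implementation.

```python
import torch
import torch.nn.functional as F

def search_loss_sketch(q_search, k_search, teacher_probs,
                       alpha=1.0, beta=1.0, tau=0.1):
    """q_search, k_search: [L, 64] search vectors for one layer.
    teacher_probs: [L, L] causal teacher attention (rows = queries).
    Hypothetical loss shapes, not the repo's exact code."""
    L = q_search.shape[0]
    scores = (q_search @ k_search.T) / tau                      # search-space logits
    future = torch.triu(torch.ones(L, L, dtype=torch.bool, device=scores.device), 1)
    scores = scores.masked_fill(future, -1e9)                   # respect causality
    # Contrastive term: the teacher's argmax key acts as the positive per query.
    positives = teacher_probs.argmax(dim=-1)                    # [L]
    contrastive = F.cross_entropy(scores, positives)
    # Distillation term: match the search-score distribution to the teacher's.
    kl = F.kl_div(F.log_softmax(scores, dim=-1), teacher_probs,
                  reduction="batchmean")
    return alpha * contrastive + beta * kl
```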
## License
The search projections are released under Apache-2.0 (matching the base model).