---
language:
- en
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
- sparse-attention
- ann-attention
- distillation
- search-projection
- inference-optimization
library_name: pytorch
---

# ann-sparseattention

Search projections for ANN-substituted attention on
[`Qwen/Qwen3-4B-Instruct-2507`](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507).

Code: [github.com/unixsysdev/ann-sparseattention](https://github.com/unixsysdev/ann-sparseattention)

## Current status

Research prototype. The trained projections work, the runtime is a
correctness-first prototype, and the evaluation envelope is narrow. Treat the
reported numbers as preliminary.

**Validated:** 6-layer pilot on Qwen3-4B-Instruct-2507; WikiText-103 PPL
preserved at K=128 (gap ≈ +0.7%); learned projections retrieve
attention-relevant keys.

**Not yet validated:** 34-layer / whole-model substitution; long-context
tasks (LongBench, RULER, needle-in-a-haystack); wall-clock speedup vs
FlashAttention/SDPA; KV-cache decode-mode integration; GPU-resident ANN kernel.

**Runtime caveat:** the FAISS path here builds CPU indexes per batch and
the gather step uses dense-style tensor expansion. Compute-reduction
numbers below are *algorithmic scoring reductions, not measured wall-clock
speedups*.

## Relation to RetrievalAttention

RetrievalAttention (Liu et al., 2024) shows that **vanilla ANN over the
model's native Q, K vectors fails**: Q and K live in mismatched
distributions, since they were never trained to be each other's nearest
neighbors, only to score via dot product. Their fix is applied at *index
time*, via attention-aware graph construction (RoarGraph-style).

This work attacks the same problem from the opposite direction. We
**train a tiny shared projection pair** (`W_Qs`, `W_Ks`, each mapping
`2560 → 64`) so that `q_search` and `k_search` live in the same
distribution by construction. Off-the-shelf FAISS HNSW with default
parameters then suffices.
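
To make that concrete, here is a minimal sketch of the search path under the setup described above. The weights and tensors are random stand-ins and the shapes are illustrative, not the repo's API:

```python
import faiss
import torch

L, d_model, d_search, K = 4096, 2560, 64, 128

# Hidden states from the frozen base model (random stand-ins here).
h = torch.randn(L, d_model)

# Learned shared search projections; in the repo these come from the checkpoint.
W_Qs = torch.randn(d_model, d_search) * 0.02
W_Ks = torch.randn(d_model, d_search) * 0.02

q_search = (h @ W_Qs).float().numpy()   # [L, 64] query-side search vectors
k_search = (h @ W_Ks).float().numpy()   # [L, 64] key-side search vectors

# Off-the-shelf FAISS HNSW with near-default parameters, inner-product metric.
index = faiss.IndexHNSWFlat(d_search, 32, faiss.METRIC_INNER_PRODUCT)
index.add(k_search)

# For each query position, retrieve K candidate key positions to attend to.
scores, topk_ids = index.search(q_search, K)   # topk_ids: [L, K]
```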

| | Search space | Index | Trainable |
|---|---|---|---|
| Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no; fails (Q/K out of distribution) |
| RetrievalAttention | original Q/K | attention-aware graph | no |
| **This work** | **learned Q\_s / K\_s** | **off-the-shelf** | **yes (~2-11M params)** |

Contribution: *eliminate the Q/K mismatch in the vectors themselves via
distillation, instead of patching around it with a specialized index.* The
clean validating experiment (vanilla FAISS over raw Q/K vs. learned
Q\_s/K\_s vs. exact teacher top-K) is the next planned run.

## What's in this repo

Per-layer linear search projections `(W_Qs, W_Ks)` of shape `[2560, 64]`,
trained against the frozen base model's attention via contrastive +
distillation losses. At inference these produce 64-d "search vectors" that
let an off-the-shelf FAISS HNSW index pick the top-K keys to attend to,
replacing dense `O(L²)` attention with `O(L·K)` ANN-substituted attention.

Layers covered (pilot): `[4, 8, 12, 16, 20, 24]`, i.e. 6 of 36 layers, ~2M trainable params.
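
For intuition, a minimal single-head sketch of the substituted attention step itself (illustrative shapes only; the repo's runtime additionally handles GQA head sharing, batching, and the full mask wiring):

```python
import torch
import torch.nn.functional as F

def ann_substituted_attention(q, k, v, topk_ids):
    """Single-head attention where each query scores only its retrieved keys.

    q, k, v:   [L, d_head]
    topk_ids:  [L, K] key positions returned by the ANN search
    """
    L, d = q.shape
    pos = torch.arange(L, device=q.device).unsqueeze(1)        # [L, 1]

    # Ensure each query can attend to itself so the masked softmax is well defined.
    ids = torch.cat([topk_ids, pos], dim=1)                    # [L, K+1]

    k_sel, v_sel = k[ids], v[ids]                              # [L, K+1, d] gathered keys/values
    scores = torch.einsum("ld,lkd->lk", q, k_sel) / d ** 0.5   # [L, K+1]
    scores = scores.masked_fill(ids > pos, float("-inf"))      # causal mask on retrieved positions
    probs = F.softmax(scores, dim=-1)                          # softmax over retrieved keys only
    return torch.einsum("lk,lkd->ld", probs, v_sel)            # [L, d_head]
```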

## Pilot results (final, 2K steps on WikiText-103)

| Step | Recall@K=128 | PPL gap (full vs ANN) |
|---|---|---|
| 500 | 47.4% | 1.21% |
| 1000 | 50.7% | 0.68% |
| 1500 | 50.9% | 0.68% |
| **2000 (final)** | **50.9%** | **0.71%** |

PPL gap is the primary signal: at <1% relative gap, the model's output
quality is preserved under ANN substitution. Recall plateaus around step 1000
because the softmax-relevant keys concentrate in the top ~30; disagreement
on positions 30-128 falls on the near-zero-weight tail and doesn't affect the output.
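
Recall@K here can be read as overlap with the dense teacher's exact top-K. A minimal sketch of that measurement (assumed formulation; the repo's `evaluate()` may differ in masking and averaging details):

```python
import torch

def recall_at_k(teacher_scores, retrieved_ids, k=128):
    """Fraction of the dense teacher's top-k keys that the ANN search retrieved.

    teacher_scores: [L, L] full (causally masked) attention scores from the teacher
    retrieved_ids:  [L, K] key positions returned by the search projections + ANN
    """
    teacher_topk = teacher_scores.topk(k, dim=-1).indices                         # [L, k]
    # For each teacher top-k key, check whether it appears in the retrieved set.
    hits = (teacher_topk.unsqueeze(-1) == retrieved_ids.unsqueeze(1)).any(-1)     # [L, k]
    return hits.float().mean().item()    # ignores edge effects at early positions
```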

### K-retrieve Pareto (pilot step 2000, FAISS HNSW)

`PPL_full = 9.958`

| K | Recall@K | PPL_ANN | PPL gap |
|---|---|---|---|
| 16 | 24.9% | 10.71 | +7.51% |
| 32 | 22.8% | 10.41 | +4.51% |
| 64 | 23.1% | 10.20 | +2.42% |
| 128 | 26.0% | 10.04 | +0.82% |
| 256 | 31.6% | 9.88 | **−0.79%** |
| 512 | 40.8% | 9.67 | **−2.89%** |

On this small WikiText slice, K ≥ 256 produced lower measured PPL than
the full-attention reference. A plausible explanation is sparse-softmax
denoising, but with only 12 eval batches, sample noise, packed-boundary
artifacts (the pilot trained with packing on; the default in the repo is now
off), and partial-layer substitution acting like regularization are also
candidates. We treat it as a hypothesis to confirm via an exact-topK oracle
(full QK^T → top-K → restricted attention) at the same K, which separates
"denoising from any sparsity" from "denoising from learned projections."

Code-level sanity checks pass: the same input sequences are used for
`ppl_full` and `ppl_ann`, the causal mask stays intact under retrieval, and
a single softmax is taken over the retrieved K with no wrapper leakage
between iterations.
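
A minimal sketch of the exact-topK oracle mentioned above (assumed formulation, single head, no batching): score densely, keep each query's exact top-K, and softmax over only those.

```python
import torch
import torch.nn.functional as F

def exact_topk_attention(q, k, v, K=256):
    """Oracle: full QK^T -> exact per-query top-K -> attention restricted to it.

    q, k, v: [L, d_head]. Isolates "any sparsity" from "learned-projection retrieval".
    """
    L, d = q.shape
    scores = (q @ k.T) / d ** 0.5                                  # [L, L] full dense scoring
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))

    # Keep only each query's exact top-K scores; mask out everything else.
    kth = scores.topk(min(K, L), dim=-1).values[:, -1:]            # [L, 1] K-th best score per query
    scores = scores.masked_fill(scores < kth, float("-inf"))

    return F.softmax(scores, dim=-1) @ v                           # [L, d_head]
```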

### Compute / quality knobs (FLOP-counted)

`L = 4096`. Compute reduction refers to the attention scoring step only, ≈ `L / K`.
These are FLOP estimates, not measured wall-clock: the FAISS path in this
repo is a research prototype that does CPU index builds and GPU↔CPU
transfers, so it is not the right thing to time.
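
As a back-of-the-envelope check of the ≈ `L / K` claim (the head dimension is illustrative; only the ratio matters):

```python
L, d_head = 4096, 128          # sequence length, per-head dim (illustrative)

def scoring_flops_dense(L, d):
    return L * L * d           # every query scores every key

def scoring_flops_ann(L, d, K):
    return L * K * d           # every query scores only its K retrieved keys

for K in (512, 256, 128, 64, 32, 16):
    ratio = scoring_flops_dense(L, d_head) / scoring_flops_ann(L, d_head, K)
    print(f"K={K:4d}  scoring reduction ~ {ratio:.0f}x")   # equals L / K
```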

| K | PPL gap | Attention scoring reduction |
|---|---|---|
| 512 | −2.89% | ~8× |
| 256 | −0.79% | ~16× |
| 128 | +0.82% | ~32× |
| 64 | +2.42% | ~64× |
| 32 | +4.51% | ~128× |
| 16 | +7.51% | ~256× |

Eval scope: 12 sequences × 4K tokens of WikiText-103 validation (~50K
tokens). Read these as "what we observed on this slice", not population-level
estimates.

The K-sweep recall numbers (24-41%) and the in-training `evaluate()` recall
(50.9% at K=128) come from different sampled subsets of the streaming split
and shouldn't be directly compared. The repo also reports `mass@K` (the sum
of teacher attention probability captured by the search top-K), which is the
more direct retrieval-quality metric when the softmax is sharp.
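
A minimal sketch of `mass@K` as described (assumed formulation; the repo may aggregate over heads and layers differently):

```python
import torch

def mass_at_k(teacher_probs, retrieved_ids):
    """Teacher attention probability mass captured by the retrieved keys.

    teacher_probs: [L, L] full-attention softmax weights from the frozen teacher
    retrieved_ids: [L, K] key positions returned by the search projections + ANN
    """
    captured = teacher_probs.gather(-1, retrieved_ids)   # [L, K] probs of retrieved keys
    return captured.sum(-1).mean().item()                # mean over query positions
```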

### Per-layer recall (pilot)

| Layer | Recall@K=128 | Recall@K=512 |
|---|---|---|
| 4 | 15.8% | 34.7% |
| 8 | 22.2% | 38.7% |
| 12 | 23.4% | 39.1% |
| 16 | 31.9% | 45.2% |
| 20 | 31.4% | 42.6% |
| 24 | 31.1% | 44.4% |

Early layers are harder for content-addressable retrieval: their attention
is more local/positional than semantic. The pattern is consistent across K,
so it is a property of the layer rather than noise.

### Caveats / what's next

- **Packing**: pilot training and eval ran with sequence packing on (and no
  segment-level causal mask, since transformers' default forward doesn't
  build one; see the sketch after this list). The relative PPL gap between
  full and ANN is internally consistent under this confound, but the negative
  gap at K ≥ 256 has at least three candidate explanations we haven't
  disentangled: (a) sparse-softmax denoising, (b) ANN happening to filter
  cross-document keys that full attention attends to, (c) sample noise on a
  small eval. The default config now has packing off so the next run isolates (a).
- **Exact-topK oracle**: a four-way Pareto (full vs. exact top-K vs.
  search-topK exact vs. search-ANN) is the natural follow-up to separate
  "denoising from any sparsity" from "denoising from learned projections."
- **Wall-clock**: not measured. The FAISS path in the repo is a CPU-side
  research prototype, not a deployable runtime. A GPU-resident top-k kernel
  is the next engineering step.
- **34-layer headline**: the run is queued (`make_headline_config()` is
  wired); its checkpoints will be mirrored here when it completes.
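
For reference, the segment-level causal mask the packing bullet refers to is a generic construction like the following (a sketch, not code from this repo): it forbids attention across packed-document boundaries on top of the usual causal constraint.

```python
import torch

def packed_causal_mask(segment_ids):
    """Boolean mask [L, L]; True where attention is allowed.

    segment_ids: [L] document id of each packed token, e.g. [0, 0, 0, 1, 1, 2, ...]
    """
    L = segment_ids.shape[0]
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool))            # position j <= i
    same_doc = segment_ids.unsqueeze(0) == segment_ids.unsqueeze(1)    # same packed document
    return causal & same_doc

# Example: two packed documents of lengths 3 and 2.
mask = packed_causal_mask(torch.tensor([0, 0, 0, 1, 1]))
```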

## Files

| File | What |
|---|---|
| `search_step_1000.pt` | Mid-training checkpoint (step 1000, 0.68% PPL gap) |
| `search_step_2000.pt` | Final pilot checkpoint (step 2000, 0.71% PPL gap) |

Each checkpoint contains `{step, search_module: state_dict, optimizer, scheduler, config}`.

## Loading

```python
import torch
from transformers import AutoModelForCausalLM
# Search module class is in the GitHub repo (model.py)
from model import SearchProjectionModule

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

search = SearchProjectionModule(
    d_model=2560, d_search=64,
    layer_indices=[4, 8, 12, 16, 20, 24],
    use_mlp=False,
).to(base.device).to(torch.bfloat16)

# weights_only=False because the checkpoint also stores optimizer/scheduler state.
ckpt = torch.load("search_step_2000.pt", map_location="cpu", weights_only=False)
search.load_state_dict(ckpt["search_module"])
```

Use `inference.install_ann_attention(...)` (in the GitHub repo) to monkey-patch
the trained layers and run with FAISS HNSW retrieval at inference time.
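
Once the patch is installed, generation goes through the standard transformers API. A usage sketch continuing from the loading snippet above (the exact `install_ann_attention` arguments are documented in the GitHub repo and omitted here):

```python
from transformers import AutoTokenizer

# install_ann_attention(...) has already been called on `base` as described above.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
messages = [{"role": "user", "content": "Summarize what ANN-substituted attention does."}]
inputs = tok.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(base.device)

out = base.generate(inputs, max_new_tokens=128, do_sample=False)
print(tok.decode(out[0, inputs.shape[-1]:], skip_special_tokens=True))
```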

## Training recipe

- Frozen base: Qwen3-4B-Instruct-2507 (36 layers, hidden size 2560, GQA 32:8).
- Data: WikiText-103 raw, 4K-token sequences (packing was on at training
  time; the default in the repo is now off, see Caveats).
- 2000 steps, batch size 8, lr 1e-4 (cosine, 100-step warmup), AdamW.
- `α = β = 1` (contrastive + KL distillation losses, averaged over the
  trained layers; see the sketch below).
- bf16 weights, fp32 loss math.
- SDPA attention (B200, no flash-attn package needed).
- Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
- Total wall-clock: ~25 min on a single B200.
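
A minimal sketch of a loss with this shape, `alpha * contrastive + beta * KL distillation` against the frozen teacher's attention. The positive-selection scheme, temperature, and masking details here are assumptions; see the GitHub repo for the actual formulation.

```python
import torch
import torch.nn.functional as F

def search_projection_loss(q_s, k_s, teacher_probs, alpha=1.0, beta=1.0, tau=0.1):
    """alpha * contrastive + beta * KL distillation, one layer / head / sequence.

    q_s, k_s:      [L, 64] learned search vectors
    teacher_probs: [L, L]  frozen teacher's (causally masked) attention weights
    """
    L = q_s.shape[0]
    logits = (q_s @ k_s.T) / tau                                   # [L, L] search-space scores
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=q_s.device))
    logits = logits.masked_fill(~causal, float("-inf"))

    # Contrastive term: each query should rank the teacher's heaviest key first.
    positives = teacher_probs.argmax(dim=-1)                       # [L]
    contrastive = F.cross_entropy(logits, positives)

    # Distillation term: match the teacher's full attention distribution.
    kl = F.kl_div(F.log_softmax(logits, dim=-1), teacher_probs, reduction="batchmean")

    return alpha * contrastive + beta * kl
```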

## License

The search projections are released under Apache-2.0 (matching the base model).
|