---
language:
- en
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
- sparse-attention
- ann-attention
- distillation
- search-projection
- inference-optimization
library_name: pytorch
---

# ann-sparseattention

Search projections for ANN-substituted attention on [`Qwen/Qwen3-4B-Instruct-2507`](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507).

Code: [github.com/unixsysdev/ann-sparseattention](https://github.com/unixsysdev/ann-sparseattention)

## Current status

Research prototype. The trained projections work, the runtime is a correctness prototype, and the eval envelope is narrow. Treat reported numbers as preliminary.

**Validated:** 6-layer pilot on Qwen3-4B-Instruct-2507; WikiText-103 PPL preserved at K=128 (gap ≈ +0.7%); learned projections retrieve attention-relevant keys.

**Not yet validated:** 34-layer / whole-model substitution; long-context tasks (LongBench, RULER, needle); wall-clock speedup vs FlashAttention/SDPA; KV-cache decode-mode integration; GPU-resident ANN kernel.

**Runtime caveat:** the FAISS path here builds CPU indexes per batch, and the gather step uses dense-style tensor expansion. Compute-reduction numbers below are *algorithmic scoring reductions, not measured wall-clock speedups.*

## Relation to RetrievalAttention

RetrievalAttention (Liu et al., 2024) shows that **vanilla ANN over the model's native Q, K vectors fails** because Q and K live in mismatched distributions — they were never trained to be each other's nearest neighbors, only to score via dot product. Their fix is at *index time*: an attention-aware graph construction (RoarGraph-style).

This work attacks the same problem from the opposite direction. We **train a tiny shared projection** (`W_Qs, W_Ks → R^64`) so that `q_search` and `k_search` live in the same distribution by construction. Off-the-shelf FAISS HNSW with default parameters then suffices.

| | Search space | Index | Trainable |
|---|---|---|---|
| Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no — fails (Q/K OOD) |
| RetrievalAttention | original Q/K | attention-aware graph | no |
| **This work** | **learned Q\_s / K\_s** | **off-the-shelf** | **yes (~2-11M params)** |

Contribution: *eliminate Q/K mismatch at index-build time via distillation, instead of patching it at search time.* The clean validating experiment — vanilla FAISS over raw Q/K vs. learned Q\_s/K\_s vs. exact teacher top-K — is the next planned run.

## What's in this repo

Per-layer linear search projections `(W_Qs, W_Ks)` of shape `[2560, 64]`, trained against the frozen base model's attention via contrastive + distillation losses. At inference these produce 64-d "search vectors" that let an off-the-shelf FAISS HNSW index pick the top-K keys to attend to, replacing dense `O(L²)` attention with `O(L·K)` ANN-substituted attention. (A minimal sketch of this substitution appears after the pilot results below.)

Layers covered (pilot): `[4, 8, 12, 16, 20, 24]` — 6 of 36 layers, ~2M trainable params.

## Pilot results (final, 2K steps on WikiText-103)

| Step | Recall@K=128 | PPL gap (full vs ANN) |
|---|---|---|
| 500 | 47.4% | 1.21% |
| 1000 | 50.7% | 0.68% |
| 1500 | 50.9% | 0.68% |
| **2000 (final)** | **50.9%** | **0.71%** |

PPL gap is the primary signal — at <1% relative gap, the model's output quality is preserved under ANN substitution. Recall plateaus around step 1000 because the softmax-relevant keys concentrate in the top ~30; disagreement on positions 30-128 falls on the near-zero-weight tail and doesn't affect output.
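For concreteness, here is the minimal sketch of the substitution referenced above. It is illustrative, not the repo's runtime (that lives in the repo's `inference` module): single head, single sequence, CPU FAISS, and all function and argument names are assumptions.

```python
# Hedged sketch of ANN-substituted attention for one layer and one head.
# Mirrors the prototype's caveats: CPU FAISS index per sequence, causal
# mask enforced by filtering retrieved ids. Names/shapes are assumptions.
import faiss
import numpy as np
import torch
import torch.nn.functional as F

def ann_attention(hidden, q, k, v, w_qs, w_ks, topk=128):
    # hidden: [L, 2560] layer input; q, k, v: [L, d_head] for one head;
    # w_qs, w_ks: [2560, 64] trained search projections for this layer.
    L = hidden.shape[0]

    # 64-d search vectors: q_search and k_search share a distribution by
    # construction, so an untuned off-the-shelf HNSW index suffices.
    q_s = (hidden @ w_qs).float().cpu().numpy()
    k_s = (hidden @ w_ks).float().cpu().numpy()

    index = faiss.IndexHNSWFlat(64, 32)       # CPU index per sequence (prototype)
    index.add(k_s)
    _, ids = index.search(q_s, min(topk, L))  # [L, topk] candidate key ids

    out = torch.zeros_like(q)
    for t in range(L):
        # Causal mask by filtering retrieved ids; always keep position t.
        # (-1 marks "not found" padding in FAISS results.)
        found = ids[t][(ids[t] >= 0) & (ids[t] <= t)]
        cand = torch.from_numpy(np.unique(np.append(found, t))).to(q.device)
        scores = (q[t] @ k[cand].T) / k.shape[-1] ** 0.5
        out[t] = F.softmax(scores, dim=-1) @ v[cand]  # O(K) scoring per query
    return out
```

A deployable runtime would batch queries, keep the index and gather on GPU, and handle the KV cache in decode mode; none of that exists yet (see Current status).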
### K-retrieve Pareto (pilot step 2000, FAISS HNSW)

`PPL_full = 9.958`

| K | Recall@K | PPL_ANN | PPL gap |
|---|---|---|---|
| 16 | 24.9% | 10.71 | +7.51% |
| 32 | 22.8% | 10.41 | +4.51% |
| 64 | 23.1% | 10.20 | +2.42% |
| 128 | 26.0% | 10.04 | +0.82% |
| 256 | 31.6% | 9.88 | **−0.79%** |
| 512 | 40.8% | 9.67 | **−2.89%** |

On this small WikiText slice, K ≥ 256 produced lower measured PPL than the full-attention reference. A plausible explanation is sparse-softmax denoising, but with only 12 eval batches, sample noise is also a candidate, as are packed-boundary artifacts (the pilot trained with packing on; the default in the repo is now off) and partial-layer substitution acting as a regularizer. We treat it as a hypothesis to be confirmed via an exact-topK oracle (full QK^T → top-K → restricted attention) at the same K, which separates "denoising from any sparsity" from "denoising from learned projections." Code-level sanity checks pass: same input sequences for `ppl_full` vs `ppl_ann`, intact causal mask in retrieval, single softmax over retrieved K with no wrapper leakage between iterations.

### Compute / quality knobs (FLOP-counted)

`L = 4096`. Compute reduction is for the attention scoring step, ≈ `L / K`. These are FLOP estimates, not measured wall-clock — the FAISS path in this repo is a research prototype that does CPU index builds and GPU↔CPU transfers, so it is not the right thing to time.

| K | PPL gap | Attention scoring reduction |
|---|---|---|
| 512 | −2.89% | ~8× |
| 256 | −0.79% | ~16× |
| 128 | +0.82% | ~32× |
| 64 | +2.42% | ~64× |
| 32 | +4.51% | ~128× |
| 16 | +7.51% | ~256× |

Eval scope: 12 sequences × 4K tokens of WikiText-103 validation (~50K tokens). Read these as "what we observed on this slice", not as population-level estimates. The K-sweep recall numbers (24–41%) and the in-training `evaluate()` recall (50.9% at K=128) come from different sampled subsets of the streaming split and shouldn't be directly compared. The repo also reports `mass@K` (the sum of teacher attention probability captured by the search top-K) — the more direct retrieval-quality metric when the softmax is sharp.

### Per-layer recall (pilot)

| Layer | Recall@K=128 | Recall@K=512 |
|---|---|---|
| 4 | 15.8% | 34.7% |
| 8 | 22.2% | 38.7% |
| 12 | 23.4% | 39.1% |
| 16 | 31.9% | 45.2% |
| 20 | 31.4% | 42.6% |
| 24 | 31.1% | 44.4% |

Early layers are harder for content-addressable retrieval — their attention is more local/positional than semantic. The ordering is consistent across K, so it's a property of the layer rather than noise.

### Caveats / what's next

- **Packing**: pilot training and eval ran with sequence packing on (no segment-level causal mask, since transformers' default forward doesn't build one). The relative PPL gap between full and ANN is internally consistent under this confound, but the negative gap at K≥256 has at least three candidate explanations we haven't disentangled — (a) sparse-softmax denoising, (b) ANN happening to filter cross-document keys that full attention attends to, (c) sample noise on a small eval. The default config now has packing off, so the next run isolates (a).
- **Exact-topK oracle**: a four-way Pareto (full vs. exact top-K vs. exact attention over search-selected top-K vs. search-ANN) is the natural follow-up to separate "denoising from any sparsity" from "denoising from learned projections"; see the sketch after this list.
- **Wall-clock**: not measured. The FAISS path in the repo is a CPU-side research prototype, not a deployable runtime. A GPU-resident top-K kernel is the next engineering step.
- **34-layer headline** run is queued (`make_headline_config()` is wired); its checkpoints will be mirrored here when it runs.
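A minimal sketch of that oracle, under the same hedges as before (single head, illustrative names, not the repo's implementation):

```python
# Exact-topK oracle: full QK^T -> causal top-K -> attention restricted to
# those keys. Dense scoring is used only to *select* keys, so any PPL change
# vs. full attention is attributable to sparsity itself, not to retrieval
# quality. Hedged sketch; the repo's actual eval code may differ.
import torch
import torch.nn.functional as F

def exact_topk_attention(q, k, v, topk=128):
    # q, k, v: [L, d_head] for one head.
    L, d = q.shape
    scores = (q @ k.T) / d ** 0.5                            # [L, L]
    causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=q.device),
                        diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    # K-th largest visible score per query row; everything below is dropped.
    # Early rows with fewer than K visible keys keep all of them.
    kth = scores.topk(min(topk, L), dim=-1).values[:, -1:]
    scores = scores.masked_fill(scores < kth, float("-inf"))
    return F.softmax(scores, dim=-1) @ v                     # single softmax
```

Running the K sweep above through this oracle at matched K gives the "any sparsity" baseline that the negative-gap hypothesis needs.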
## Files

| File | What |
|---|---|
| `search_step_1000.pt` | Mid-training checkpoint (step 1000, 0.68% PPL gap) |
| `search_step_2000.pt` | Final pilot checkpoint (step 2000, 0.71% PPL gap) |

Each contains `{step, search_module: state_dict, optimizer, scheduler, config}`.

## Loading

```python
import torch
from transformers import AutoModelForCausalLM

# Search module class is in the GitHub repo (model.py)
from model import SearchProjectionModule

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

search = SearchProjectionModule(
    d_model=2560,
    d_search=64,
    layer_indices=[4, 8, 12, 16, 20, 24],
    use_mlp=False,
).to(base.device).to(torch.bfloat16)

# Checkpoint also holds optimizer/scheduler state, hence weights_only=False.
ckpt = torch.load("search_step_2000.pt", map_location="cpu", weights_only=False)
search.load_state_dict(ckpt["search_module"])
```

Use `inference.install_ann_attention(...)` (in the GitHub repo) to monkey-patch the trained layers and run with FAISS HNSW retrieval at inference time.

## Training recipe

- Frozen base: Qwen3-4B-Instruct-2507 (36 layers, hidden 2560, GQA 32:8).
- Data: WikiText-103 raw, 4K-token sequences (packing was on at training time; the default in the repo is now off — see Caveats).
- 2000 steps, batch 8, lr 1e-4 (cosine, 100-step warmup), AdamW.
- `α=β=1` (contrastive + KL distillation losses, averaged over the trained layers); a hedged sketch of this objective appears at the end of this card.
- bf16 weights, fp32 loss math.
- SDPA attention (B200, no flash-attn package needed).
- Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
- Total wall-clock: ~25 min on a single B200.

## License

The search projections are released under Apache-2.0 (matching the base model).
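For reference, the objective sketch named in the Training recipe. This is a reconstruction from the card's description, not the repo's code: the temperature, positive-key selection, and normalization details are assumptions.

```python
# Hedged sketch of the per-layer training loss: contrastive (teacher top-K
# keys as positives) + KL distillation from the teacher's attention
# distribution, combined with alpha = beta = 1. Reconstruction, not train.py.
import torch
import torch.nn.functional as F

def search_loss(q_s, k_s, teacher_attn, topk=128, tau=0.1):
    # q_s, k_s: [L, 64] search vectors for one trained layer (fp32 loss math).
    # teacher_attn: [L, L] frozen model's causal attention probabilities.
    L = q_s.shape[0]
    sims = (q_s @ k_s.T) / tau
    causal = torch.triu(torch.ones_like(sims, dtype=torch.bool), diagonal=1)
    log_p = F.log_softmax(sims.masked_fill(causal, float("-inf")), dim=-1)
    log_p = log_p.masked_fill(causal, 0.0)   # keep 0 * log_p well-defined

    # Contrastive term: pull each query toward its teacher top-K keys.
    pos = torch.zeros_like(teacher_attn)
    pos.scatter_(-1, teacher_attn.topk(min(topk, L), dim=-1).indices, 1.0)
    pos = pos.masked_fill(causal, 0.0)
    contrastive = -(pos * log_p).sum(-1).div(pos.sum(-1).clamp(min=1)).mean()

    # Distillation term: forward KL from the teacher's attention distribution
    # to the search softmax, over causally visible positions only.
    t = teacher_attn.masked_fill(causal, 0.0)
    kl = (t * (t.clamp(min=1e-9).log() - log_p)).sum() / L

    return contrastive + kl                  # alpha = beta = 1
```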