ann-sparseattention

Search projections for ANN-substituted attention on Qwen/Qwen3-4B-Instruct-2507.

Code: github.com/unixsysdev/ann-sparseattention

Current status

Research prototype. The trained projections work, the runtime is a correctness prototype, and the evaluation envelope is narrow. Treat all reported numbers as preliminary.

Validated: 6-layer pilot on Qwen3-4B-Instruct-2507; WikiText-103 PPL preserved at K=128 (gap ≈ +0.7%); learned projections retrieve attention-relevant keys.

Not yet validated: 34-layer / whole-model substitution; long-context tasks (LongBench, RULER, needle); wall-clock speedup vs FlashAttention/SDPA; KV-cache decode-mode integration; GPU-resident ANN kernel.

Runtime caveat: the FAISS path here builds CPU indexes per batch and the gather step uses dense-style tensor expansion. Compute-reduction numbers below are algorithmic scoring reductions, not measured wall-clock speedups.

Relation to RetrievalAttention

RetrievalAttention (Liu et al., 2024) shows that vanilla ANN over the model's native Q, K vectors fails because Q and K live in mismatched distributions: they were never trained to be each other's nearest neighbors, only to score via dot product. Their fix is at index time: an attention-aware graph construction (RoarGraph-style).

This work attacks the same problem from the opposite direction. We train a tiny shared projection (W_Qs, W_Ks → R^64) so that q_search and k_search live in the same distribution by construction. Off-the-shelf FAISS HNSW with default parameters then suffices.

Approach                 Search space        Index                  Trainable
Raw Q/K + vanilla ANN    original Q/K        off-the-shelf          no (fails: Q/K OOD)
RetrievalAttention       original Q/K        attention-aware graph  no
This work                learned Q_s / K_s   off-the-shelf          yes (~2-11M params)

Contribution: eliminate the Q/K mismatch at index-build time via distillation, instead of patching it at search time. The clean validating experiment (vanilla FAISS over raw Q/K vs. learned Q_s/K_s vs. exact teacher top-K) is the next planned run.

What's in this repo

Per-layer linear search projections (W_Qs, W_Ks) of shape [2560, 64], trained against the frozen base model's attention via contrastive + distillation losses. At inference these produce 64-d "search vectors" that let an off-the-shelf FAISS HNSW index pick the top-K keys to attend to, replacing dense O(L²) attention with O(L·K) ANN-substituted attention.

Layers covered (pilot): [4, 8, 12, 16, 20, 24], i.e. 6 of 36 layers, ~2M trainable params.
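As a concrete illustration, here is a minimal sketch of the retrieval step for one layer: the learned [2560, 64] projections map hidden states to search vectors, an off-the-shelf FAISS HNSW index over the key search vectors returns top-K candidate positions, and attention is then computed only over those keys. Function and variable names are illustrative, not the repo's actual API; causal filtering of the candidates and the HNSW metric/parameters are left as assumptions.

import torch
import faiss

def ann_topk_indices(hidden, w_qs, w_ks, topk=128):
    """Pick, for each query position, top-K key positions via FAISS HNSW.
    hidden: [L, 2560] layer hidden states (illustrative shape).
    w_qs, w_ks: [2560, 64] learned search projections (the frozen base is untouched)."""
    q_s = (hidden @ w_qs).float().cpu().numpy()   # [L, 64] query search vectors
    k_s = (hidden @ w_ks).float().cpu().numpy()   # [L, 64] key search vectors

    index = faiss.IndexHNSWFlat(64, 32)           # 64-d HNSW (default L2 metric; repo may differ)
    index.add(k_s)                                # CPU index built per sequence/batch
    _, idx = index.search(q_s, topk)              # [L, topk] nearest key positions
    return torch.from_numpy(idx)                  # caller drops non-causal hits before attention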

Pilot results (final, 2K steps on WikiText-103)

Step Recall@K=128 PPL gap (full vs ANN)
500 47.4% 1.21%
1000 50.7% 0.68%
1500 50.9% 0.68%
2000 (final) 50.9% 0.71%

PPL gap is the primary signal: at <1% relative gap, the model's output quality is preserved under ANN substitution. Recall plateaus around step 1000 because the softmax-relevant keys concentrate in the top ~30; disagreement on positions 30-128 falls on the near-zero-weight tail and doesn't affect the output.

K-retrieve Pareto (pilot step 2000, FAISS HNSW)

PPL_full = 9.958

K Recall@K PPL_ANN PPL gap
16 24.9% 10.71 +7.51%
32 22.8% 10.41 +4.51%
64 23.1% 10.20 +2.42%
128 26.0% 10.04 +0.82%
256 31.6% 9.88 -0.79%
512 40.8% 9.67 -2.89%

On this small WikiText slice, K ≥ 256 produced lower measured PPL than the full-attention reference. A plausible explanation is sparse-softmax denoising, but with only 12 eval batches, sample noise, packed-boundary artifacts (the pilot trained with packing on; the repo default is now off), and partial-layer substitution acting like regularization are also candidates. We treat it as a hypothesis to confirm via an exact-topK oracle (full QK^T → top-K → restricted attention) at the same K, which separates "denoising from any sparsity" from "denoising from learned projections."
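A minimal sketch of that oracle, assuming a single head and ignoring GQA/RoPE details for clarity: compute exact scores, keep only the top-K causal positions per query, and softmax over the restricted set. Names are illustrative.

import torch

def exact_topk_attention(q, k, v, topk=128):
    """Oracle: exact QK^T -> causal top-K -> attention restricted to those keys.
    q, k, v: [L, d] single-head tensors (illustrative; the real model uses GQA 32:8)."""
    L, d = q.shape
    topk = min(topk, L)
    scores = (q @ k.T) / d**0.5                                  # [L, L] exact scores
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    top_val, top_idx = scores.topk(topk, dim=-1)                 # keep best K per query
    restricted = torch.full_like(scores, float("-inf")).scatter(-1, top_idx, top_val)
    probs = torch.softmax(restricted, dim=-1)                    # softmax over retrieved keys only
    return probs @ v                                             # [L, d]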

Code-level sanity checks pass: same input sequences for ppl_full vs ppl_ann, intact causal mask in retrieval, single softmax over retrieved K with no wrapper leakage between iterations.

Compute / quality knobs (FLOP-counted)

L = 4096. Compute reduction refers to the attention scoring step, ≈ L / K. These are FLOP estimates, not measured wall-clock: the FAISS path in this repo is a research prototype that does CPU index builds and GPU-to-CPU transfers, so it is not the right thing to time.

K PPL gap Attention scoring reduction
512 -2.89% ~8×
256 -0.79% ~16×
128 +0.82% ~32×
64 +2.42% ~64×
32 +4.51% ~128×
16 +7.51% ~256×
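The reduction column is just the ratio of dense to sparse scoring FLOPs; a quick back-of-envelope check (head dimension cancels in the ratio, so it is omitted):

L = 4096
for K in (512, 256, 128, 64, 32, 16):
    dense = 2 * L * L    # score every query against every key (per head, head dim omitted)
    sparse = 2 * L * K   # score every query against only its K retrieved keys
    print(f"K={K:4d}  scoring reduction ~{dense / sparse:.0f}x")   # = L / K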

Eval scope: 12 sequences × 4K tokens of WikiText-103 validation (~50K tokens). Read these as "what we observed on this slice", not population-level estimates.

The K-sweep recall numbers (24-41%) and the in-training evaluate() recall (50.9% at K=128) come from different sampled subsets of the streaming split and shouldn't be directly compared. The repo also reports mass@K (the sum of teacher attention probability captured by the search top-K), which is the more direct retrieval-quality metric when the softmax is sharp.
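For concreteness, a small sketch of the two metrics for a single query position, given the teacher's full-attention probabilities and the ANN-retrieved key positions (names illustrative):

import torch

def recall_and_mass_at_k(teacher_probs, retrieved_idx, k=128):
    """teacher_probs: [L] full-attention probabilities for one query position.
    retrieved_idx: [k] key positions returned by the search projections + ANN."""
    teacher_topk = teacher_probs.topk(k).indices                   # exact top-K teacher keys
    recall = torch.isin(retrieved_idx, teacher_topk).float().mean()  # fraction of top-K recovered
    mass = teacher_probs[retrieved_idx].sum()                      # mass@K: probability captured
    return recall.item(), mass.item()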

Per-layer recall (pilot)

Layer Recall@K=128 Recall@K=512
4 15.8% 34.7%
8 22.2% 38.7%
12 23.4% 39.1%
16 31.9% 45.2%
20 31.4% 42.6%
24 31.1% 44.4%

Early layers are harder for content-addressable retrieval: their attention is more local/positional than semantic. The pattern is consistent across K, so it's a property of the layer rather than noise.

Caveats / what's next

  • Packing: pilot training and eval ran with sequence packing on (no segment-level causal mask, since transformers' default forward doesn't build one; see the mask sketch after this list). The relative PPL gap between full and ANN is internally consistent under this confound, but the negative gap at K≥256 has at least three candidate explanations we haven't disentangled: (a) sparse-softmax denoising, (b) ANN happening to filter cross-document keys that full attention attends to, (c) sample noise on a small eval. The default config now has packing off so the next run isolates (a).
  • Exact-topK oracle: a four-way Pareto (full vs. exact top-K vs. search-topK exact vs. search-ANN) is the natural follow-up to separate "denoising from any sparsity" from "denoising from learned projections."
  • Wall-clock: not measured. The FAISS path in the repo is a CPU-side research prototype, not a deployable runtime. A GPU-resident top-K kernel is the next engineering step.
  • The 34-layer headline run is queued (make_headline_config() is wired); its checkpoints will be mirrored here when it runs.
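For reference, the segment-level causal mask that packing would need (and that the pilot did not build) can be sketched as follows; document_ids marks which packed document each token belongs to, and the function name is illustrative, not the repo's code:

import torch

def packed_causal_mask(document_ids):
    """Boolean [L, L] mask: token i may attend to token j only if j <= i
    and both tokens come from the same packed document."""
    L = document_ids.shape[0]
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool))
    same_doc = document_ids.unsqueeze(0) == document_ids.unsqueeze(1)
    return causal & same_doc   # pass as a boolean attn_mask (True = attend) to SDPA-style attention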

Files

File What
search_step_1000.pt Mid-training checkpoint (step 1000, 0.68% PPL gap)
search_step_2000.pt Final pilot checkpoint (step 2000, 0.71% PPL gap)

Each contains {step, search_module: state_dict, optimizer, scheduler, config}.

Loading

import torch
from transformers import AutoModelForCausalLM
# Search module class is in the GitHub repo (model.py)
from model import SearchProjectionModule

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

search = SearchProjectionModule(
    d_model=2560, d_search=64,
    layer_indices=[4, 8, 12, 16, 20, 24],
    use_mlp=False,
).to(base.device).to(torch.bfloat16)

ckpt = torch.load("search_step_2000.pt", map_location="cpu", weights_only=False)
search.load_state_dict(ckpt["search_module"])

Use inference.install_ann_attention(...) (in the GitHub repo) to monkey-patch the trained layers and run with FAISS HNSW retrieval at inference time.
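A hypothetical continuation of the loading snippet above; install_ann_attention lives in the repo's inference module, but the keyword arguments shown here (layer_indices, k) are assumptions about its interface, not its documented signature:

from inference import install_ann_attention
from transformers import AutoTokenizer

# Hypothetical call: the exact argument names may differ in the repo.
install_ann_attention(base, search, layer_indices=[4, 8, 12, 16, 20, 24], k=128)

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
inputs = tok("The quick brown fox", return_tensors="pt").to(base.device)
out = base.generate(**inputs, max_new_tokens=32)
print(tok.decode(out[0], skip_special_tokens=True))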

Training recipe

  • Frozen base: Qwen3-4B-Instruct-2507 (36 layers, hidden 2560, GQA 32:8).
  • Data: WikiText-103 raw, 4K-token sequences (packing was on at training time; the default in the repo is now off; see Caveats).
  • 2000 steps, batch 8, lr 1e-4 (cosine, 100-step warmup), AdamW.
  • α=β=1 (contrastive + KL distillation losses, both averaged across the trained layers; see the loss sketch after this list).
  • bf16 weights, fp32 loss math.
  • SDPA attention (B200, no flash-attn package needed).
  • Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
  • Total wall-clock: ~25 min on a single B200.
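A minimal sketch of the per-layer training objective under those settings, assuming the teacher's attention probabilities serve both as distillation targets and as the source of positives for the contrastive term; names, the temperature, and the positive-sampling scheme are assumptions, not the repo's code:

import torch
import torch.nn.functional as F

def search_losses(q_s, k_s, teacher_probs, temperature=0.07, topk_pos=32):
    """q_s, k_s: [L, 64] search vectors for one layer; teacher_probs: [L, L] full attention.
    Returns (contrastive, kl); total = alpha * contrastive + beta * kl with alpha = beta = 1.
    Causal masking omitted for brevity."""
    logits = (q_s @ k_s.T) / temperature                        # student similarity in search space

    # KL distillation: match the teacher's attention distribution per query.
    kl = F.kl_div(F.log_softmax(logits, dim=-1), teacher_probs, reduction="batchmean")

    # Contrastive term: treat the teacher's top keys as positives for each query.
    pos = teacher_probs.topk(topk_pos, dim=-1).indices
    log_p = F.log_softmax(logits, dim=-1)
    contrastive = -log_p.gather(-1, pos).mean()
    return contrastive, kl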

License

The search projections are released under Apache-2.0 (matching the base model).
