---
language:
- en
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
- sparse-attention
- ann-attention
- distillation
- search-projection
- inference-optimization
library_name: pytorch
---

# ann-sparseattention

Search projections for ANN-substituted attention on
[`Qwen/Qwen3-4B-Instruct-2507`](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507).

Code: [github.com/unixsysdev/ann-sparseattention](https://github.com/unixsysdev/ann-sparseattention)

## Current status

Research prototype. The trained projections work, the runtime is a correctness
prototype (not a fast one), and the eval envelope is narrow. Treat reported
numbers as preliminary.

**Validated:** 6-layer pilot on Qwen3-4B-Instruct-2507; WikiText-103 PPL
preserved at K=128 (gap ≈ +0.7%); learned projections retrieve
attention-relevant keys.

**Not yet validated:** 34-layer / whole-model substitution; long-context
tasks (LongBench, RULER, needle); wall-clock speedup vs FlashAttention/SDPA;
KV-cache decode-mode integration; GPU-resident ANN kernel.

**Runtime caveat:** the FAISS path here builds CPU indexes per batch and
the gather step uses dense-style tensor expansion. Compute-reduction
numbers below are *algorithmic scoring reductions, not measured wall-clock
speedups.*

## Relation to RetrievalAttention

RetrievalAttention (Liu et al., 2024) shows that **vanilla ANN over the
model's native Q, K vectors fails** because Q and K live in mismatched
distributions: they were never trained to be each other's nearest
neighbors, only to score via dot product. Their fix is at *index time*:
an attention-aware graph construction (RoarGraph-style).

This work attacks the same problem from the opposite direction. We
**train a tiny shared projection** (`W_Qs, W_Ks β†’ R^64`) so that
`q_search` and `k_search` live in the same distribution by construction.
Off-the-shelf FAISS HNSW with default parameters then suffices.

| | Search space | Index | Trainable |
|---|---|---|---|
| Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no; fails (Q/K OOD) |
| RetrievalAttention | original Q/K | attention-aware graph | no |
| **This work** | **learned Q\_s / K\_s** | **off-the-shelf** | **yes (~2-11M params)** |

Contribution: *eliminate the Q/K mismatch before the index is built, via
distillation, rather than patching it with attention-aware index
construction.* The clean validating experiment (vanilla FAISS over raw Q/K
vs. learned Q\_s/K\_s vs. exact teacher top-K) is the next planned run.
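
A minimal sketch of that comparison, assuming single-head `[L, d]` NumPy
arrays and an exact inner-product index standing in for HNSW (all names here
are illustrative, not the repo's API):

```python
import numpy as np
import faiss

def recall_at_k(q_search, k_search, q_teacher, k_teacher, K=128):
    """Fraction of the exact teacher top-K retrieved by ANN over the search vectors.

    q_search, k_search:   vectors fed to the index, [L, d_search]
    q_teacher, k_teacher: native Q/K that define the teacher top-K, [L, d_head]
    (causal masking omitted for brevity)
    """
    L = q_search.shape[0]
    teacher = np.argsort(q_teacher @ k_teacher.T, axis=1)[:, -K:]  # exact top-K per query
    index = faiss.IndexFlatIP(k_search.shape[1])  # exact IP index; the repo uses HNSW
    index.add(k_search.astype(np.float32))
    _, retrieved = index.search(q_search.astype(np.float32), K)
    hits = sum(len(set(teacher[i]) & set(retrieved[i])) for i in range(L))
    return hits / (L * K)

# Same teacher top-K, two search spaces:
#   recall_at_k(Q, Kmat, Q, Kmat)                # vanilla ANN over raw Q/K
#   recall_at_k(Q @ W_Qs, Kmat @ W_Ks, Q, Kmat)  # ANN over learned search vectors
```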

## What's in this repo

Per-layer linear search projections `(W_Qs, W_Ks)` of shape `[2560, 64]`,
trained against the frozen base model's attention via contrastive +
distillation losses. At inference these produce 64-d "search vectors" that
let an off-the-shelf FAISS HNSW index pick the top-K keys to attend to,
replacing dense `O(LΒ²)` attention with `O(LΒ·K)` ANN-substituted attention.

Layers covered (pilot): `[4, 8, 12, 16, 20, 24]`, i.e. 6 of 36 layers, ~2M trainable params.
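
With `use_mlp=False`, the checkpoint reduces to two linear maps per covered
layer. A minimal stand-in, assuming the maps are bias-free (inferred from the
stated `[2560, 64]` weight shapes) and read the layer's pre-attention hidden
states; the real class is `SearchProjectionModule` in the repo's `model.py`:

```python
import torch.nn as nn

class SearchProjectionsSketch(nn.Module):
    """Illustrative stand-in: 6 layers x (W_Qs, W_Ks), each [2560 -> 64] (~2M params)."""
    def __init__(self, d_model=2560, d_search=64,
                 layer_indices=(4, 8, 12, 16, 20, 24)):
        super().__init__()
        self.q_proj = nn.ModuleDict(
            {str(i): nn.Linear(d_model, d_search, bias=False) for i in layer_indices})
        self.k_proj = nn.ModuleDict(
            {str(i): nn.Linear(d_model, d_search, bias=False) for i in layer_indices})

    def forward(self, layer_idx, hidden_states):
        # hidden_states: [B, L, 2560] -> 64-d search vectors for queries and keys
        i = str(layer_idx)
        return self.q_proj[i](hidden_states), self.k_proj[i](hidden_states)
```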

## Pilot results (final, 2K steps on WikiText-103)

| Step | Recall@K=128 | PPL gap (full vs ANN) |
|---|---|---|
| 500 | 47.4% | 1.21% |
| 1000 | 50.7% | 0.68% |
| 1500 | 50.9% | 0.68% |
| **2000 (final)** | **50.9%** | **0.71%** |

PPL gap is the primary signal: at <1% relative gap, the model's output
quality is preserved under ANN substitution. Recall plateaus around step 1000
because the softmax-relevant keys concentrate in the top ~30; disagreement
over positions 30-128 falls on the near-zero-weight tail and doesn't affect
the output.

### K-retrieve Pareto (pilot step 2000, FAISS HNSW)

`PPL_full = 9.958`

| K | Recall@K | PPL_ANN | PPL gap |
|---|---|---|---|
| 16 | 24.9% | 10.71 | +7.51% |
| 32 | 22.8% | 10.41 | +4.51% |
| 64 | 23.1% | 10.20 | +2.42% |
| 128 | 26.0% | 10.04 | +0.82% |
| 256 | 31.6% | 9.88 | **−0.79%** |
| 512 | 40.8% | 9.67 | **−2.89%** |

On this small WikiText slice, K ≥ 256 produced lower measured PPL than
the full-attention reference. A plausible explanation is sparse-softmax
denoising, but sample noise (only 12 eval batches), packed-boundary artifacts
(the pilot trained with packing on; the repo default is now off), and
partial-layer substitution acting as a regularizer are also candidates.
We treat it as a hypothesis to confirm via an exact-topK oracle (full QK^T
→ top-K → restricted attention) at the same K; that separates "denoising
from any sparsity" from "denoising from learned projections."
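
A sketch of that oracle for a single head (names and shapes illustrative):
compute the dense scores exactly, keep only the top-K causal scores per
query, and run the softmax over the restricted set.

```python
import torch
import torch.nn.functional as F

def exact_topk_attention(q, k, v, K=256):
    """Oracle: dense QK^T -> causal mask -> exact top-K per query -> restricted softmax."""
    L, d = q.shape
    scores = (q @ k.T) / d ** 0.5                                  # exact [L, L] scores
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    topk = scores.topk(min(K, L), dim=-1).indices                  # exact top-K indices
    keep = torch.zeros(L, L, dtype=torch.bool, device=q.device).scatter_(1, topk, True)
    scores = scores.masked_fill(~keep, float("-inf"))              # drop everything else
    return F.softmax(scores, dim=-1) @ v                           # attend over <= K keys
```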

Code-level sanity checks pass: the same input sequences for `ppl_full` and
`ppl_ann`, an intact causal mask during retrieval, and a single softmax over
the retrieved K with no wrapper state leaking between iterations.

### Compute / quality knobs (FLOP-counted)

`L = 4096`. The compute reduction applies to the attention scoring step and
is ≈ `L / K`. These are FLOP estimates, not measured wall-clock: the FAISS
path in this repo is a research prototype that does CPU index builds and
GPU↔CPU transfers, so it is not the right thing to time.

| K | PPL gap | Attention scoring reduction |
|---|---|---|
| 512 | −2.89% | ~8× |
| 256 | −0.79% | ~16× |
| 128 | +0.82% | ~32× |
| 64 | +2.42% | ~64× |
| 32 | +4.51% | ~128× |
| 16 | +7.51% | ~256× |
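
For concreteness, the reduction column is plain arithmetic (scores computed
per query under dense vs. ANN-substituted attention), nothing more:

```python
L = 4096
for K in (512, 256, 128, 64, 32, 16):
    # dense scoring touches ~L keys per query; ANN-substituted scoring touches ~K
    print(f"K={K:>3}: ~{L // K}x fewer attention scores")
```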

Eval scope: 12 sequences × 4K tokens of WikiText-103 validation (~50K
tokens). Read these as "what we observed on this slice," not as
population-level estimates.

The K-sweep recall numbers (24–41%) and the in-training `evaluate()` recall
(50.9% at K=128) come from different sampled subsets of the streaming split
and shouldn't be directly compared. The repo also reports `mass@K` (the sum
of teacher attention probability captured by the search top-K), which is the
more direct retrieval-quality metric when the softmax is sharp.
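
A sketch of that metric, assuming a `[L, L]` teacher probability matrix and
the `[L, K]` retrieved index set (names illustrative):

```python
import torch

def mass_at_k(teacher_probs, retrieved_idx):
    """Teacher attention probability mass captured by the search top-K.

    teacher_probs: [L, L] softmax of the full model's attention scores
    retrieved_idx: [L, K] key indices returned by the search-vector ANN
    """
    return teacher_probs.gather(1, retrieved_idx).sum(dim=-1).mean().item()
```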

### Per-layer recall (pilot)

| Layer | Recall@K=128 | Recall@K=512 |
|---|---|---|
| 4 | 15.8% | 34.7% |
| 8 | 22.2% | 38.7% |
| 12 | 23.4% | 39.1% |
| 16 | 31.9% | 45.2% |
| 20 | 31.4% | 42.6% |
| 24 | 31.1% | 44.4% |

Early layers are harder for content-addressable retrieval; their attention
is more local/positional than semantic. The pattern is consistent across K,
so it's a property of the layer rather than noise.

### Caveats / what's next

- **Packing**: pilot training and eval ran with sequence packing on (no
  segment-level causal mask, since transformers' default forward doesn't
  build one). The relative PPL gap between full and ANN is internally
  consistent under this confound, but the negative gap at K≥256 has at
  least three candidate explanations we haven't disentangled:
  (a) sparse-softmax denoising, (b) ANN happening to filter cross-document
  keys that full attention attends to, (c) sample noise on a small eval.
  The default config now has packing off, so the next run isolates (a).
- **Exact-topK oracle**: a four-way Pareto (full vs. exact top-K vs.
  search-topK exact vs. search-ANN) is the natural follow-up to separate
  "denoising from any sparsity" from "denoising from learned projections."
- **Wall-clock**: not measured. The FAISS path in the repo is a CPU-side
  research prototype, not a deployable runtime. A GPU-resident top-K kernel
  is the next engineering step.
- **34-layer headline**: queued (`make_headline_config()` is wired); its
  checkpoints will be mirrored here when the run completes.

## Files

| File | What |
|---|---|
| `search_step_1000.pt` | Mid-training checkpoint (step 1000, 0.68% PPL gap) |
| `search_step_2000.pt` | Final pilot checkpoint (step 2000, 0.71% PPL gap) |

Each contains `{step, search_module: state_dict, optimizer, scheduler, config}`.

## Loading

```python
import torch
from transformers import AutoModelForCausalLM
# Search module class is in the GitHub repo (model.py)
from model import SearchProjectionModule

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

search = SearchProjectionModule(
    d_model=2560, d_search=64,
    layer_indices=[4, 8, 12, 16, 20, 24],
    use_mlp=False,
).to(base.device).to(torch.bfloat16)

ckpt = torch.load("search_step_2000.pt", map_location="cpu", weights_only=False)
search.load_state_dict(ckpt["search_module"])
```

Use `inference.install_ann_attention(...)` (in the GitHub repo) to monkey-patch
the trained layers and run with FAISS HNSW retrieval at inference time.

## Training recipe

- Frozen base: Qwen3-4B-Instruct-2507 (36 layers, hidden 2560, GQA 32:8).
- Data: WikiText-103 raw, 4K-token sequences (packing was on at training
  time; the repo default is now off; see Caveats).
- 2000 steps, batch 8, lr 1e-4 (cosine, 100-step warmup), AdamW.
- `α=β=1` (contrastive + KL distillation, both averaged over the trained layers).
- bf16 weights, fp32 loss math.
- SDPA attention (B200, no flash-attn package needed).
- Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
- Total wall-clock: ~25 min on a single B200.
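
A hedged sketch of a per-layer objective consistent with the recipe above
(the actual loss lives in the GitHub repo; the InfoNCE form, positive
selection, and temperature here are assumptions, and causal masking is
omitted):

```python
import torch
import torch.nn.functional as F

def search_loss(q_s, k_s, teacher_probs, alpha=1.0, beta=1.0, tau=0.1):
    """Illustrative per-layer loss: alpha * contrastive + beta * KL distillation.

    q_s, k_s:      [L, 64] search vectors from the trainable projections
    teacher_probs: [L, L] frozen model's attention probabilities
    """
    search_logits = (q_s @ k_s.T) / tau                    # [L, L] search scores
    # Contrastive (InfoNCE-style, assumed): the teacher's top key should win
    # the softmax over all keys for each query.
    positives = teacher_probs.argmax(dim=-1)               # [L]
    contrastive = F.cross_entropy(search_logits, positives)
    # Distillation: KL(teacher || search) over the key distribution per query.
    kl = F.kl_div(F.log_softmax(search_logits, dim=-1),
                  teacher_probs, reduction="batchmean")
    return alpha * contrastive + beta * kl
```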

## License

The search projections are released under Apache-2.0 (matching the base model).