libertywing committed on
Commit 9437df5 · 1 Parent(s): 9dcfd27

Streamline model card

Files changed (1)
  1. README.md +38 -284
README.md CHANGED
@@ -6,7 +6,6 @@ tags:
6
  - retrieval
7
  - kv-cache
8
  - sparse-attention
9
- - compress-sparse-attention
10
  - long-context
11
  - flashmemory
12
  datasets:
@@ -18,321 +17,76 @@ datasets:
18
 
19
  # FlashMemory DS-V4 Retriever
20
 
21
- A standalone, dependency-light reference implementation of the **FlashMemory DS-V4
22
- Retriever**, a lightweight retriever that sparsifies the **DeepSeek-V4
23
- Compressed-Sparse-Attention (CSA)** KV cache.
24
 
25
- Given the hidden state of a decode token, the retriever predicts which CSA
26
- KV-cache chunks (compressed keys) the upcoming tokens will attend to, so that
27
- only the **top-scoring chunks** need to stay resident on the GPU and the rest can
28
- be offloaded to CPU / disk. This recovers most of the quality of full attention
29
- on long-context tasks while keeping a small fraction of the KV cache on-device.
30
 
31
- This release contains the **algorithm + weights + a minimal, runnable PyTorch
32
- demo**. It depends only on `torch` (plus `numpy` / `safetensors` for convenience).
33
-
34
- > **Scope note.** The full sglang serving integration — KV-cache swap-in/out,
35
- > attention-sink, threshold fallback, per-request retriever routing — is **not**
36
- > included here, because it is tightly coupled to the internal DeepSeek-V4 CSA
37
- > framework and cannot run outside it. This repository provides the retriever
38
- > **algorithm reference implementation and trained weights only.**
39
-
40
- ---
41
-
42
- ## Model architecture
43
-
44
- The retriever scores each compressed-K chunk against the decode token's hidden
45
- state. For a single CSA layer:
46
-
47
- ```
48
- hidden [B, 4096]
49
- → wq_a (4096 → Q_LORA_RANK)
50
- → RMSNorm (q_norm_weight, eps=1e-6)
51
- → wq_b (Q_LORA_RANK → N_HEADS * HEAD_DIM)
52
- → reshape [B, N_HEADS, HEAD_DIM]
53
- → RoPE (YaRN, applied to the last ROPE_DIM=64 dims, base=160000)
54
- → Hadamard (normalized Walsh-Hadamard transform)
55
- → q [B, N_HEADS, HEAD_DIM]
56
-
57
- hidden [B, 4096]
58
- → weights_proj (4096 → N_HEADS)
59
- → × weight_scale (= HEAD_DIM^-0.5 * N_HEADS^-0.5)
60
- → fused_w [B, N_HEADS]
61
-
62
- compressed_k [B, N, HEAD_DIM + 4] (uint8)
63
- → bytes[:HEAD_DIM] viewed as float8_e4m3 → dequantize
64
- → bytes[HEAD_DIM:] viewed as float32 → per-chunk scale
65
- → k [B, N, HEAD_DIM]
66
-
67
- score_per_head = relu( einsum('bnd,bhd->bnh', k, q) ) # [B, N, N_HEADS]
68
- logit = (score_per_head * fused_w[:, None, :]).sum(-1) # [B, N]
69
- score = sigmoid(logit) ∈ [0, 1] # [B, N]
70
- ```
71
-
72
- **Hyperparameters (FlashMemory DS-V4):** `Q_LORA_RANK = 2048`, `N_HEADS = 128`,
73
- `HEAD_DIM = 128`, `ROPE_DIM = 64`, `ROPE_BASE = 160000`, `ROPE_FACTOR = 16`,
74
- `ROPE_ORIGINAL_SEQ_LEN = 65536`, `ROPE_BETA_FAST = 32`, `ROPE_BETA_SLOW = 1`,
75
- `RMS_NORM_EPS = 1e-6`.
76
-
77
- ### Joint multi-layer checkpoint + ensemble
78
-
79
- FlashMemory DS-V4 is a **joint checkpoint** holding three independent CSA layers
80
- (`l10`, `l12`, `l20`), each with its own weights. At inference time the per-layer
81
- sigmoid scores are **ensembled per chunk** — cross-layer `max` (default) or
82
- `mean` — to produce a single keep/drop decision per chunk.
83
-
84
- ---
85
-
86
- ## What is FlashMemory DS-V4?
87
-
88
- FlashMemory DS-V4 is part of the latest retraining generation of these retrievers. In the
89
- project's downstream evaluation it stays close to the full-attention baseline on
90
- long-context tasks (e.g. RULER, LongMemEval, LongBench V2) while keeping only a
91
- small fraction of the CSA KV cache on-device (≈90% KV reduction in the deployment
92
- sweet spot for reasoning-heavy long-context tasks). Precise-needle retrieval
93
- tasks need an extra threshold-fallback mechanism in the serving layer (not part
94
- of this standalone release).
95
-
96
- ---
97
-
98
- ## Installation
99
-
100
- ```bash
101
- pip install -r requirements.txt
102
- ```
103
-
104
- Only `torch` is strictly required to run the model and demo. `float8_e4m3`
105
- tensor support requires a reasonably recent PyTorch (≥ 2.1).
106
-
107
- ---
108
-
109
- ## Running the demo
110
 
111
  ```bash
 
112
  python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
113
  ```
114
 
115
- The demo builds **random mock inputs** (a batch of decode-token hidden states, a
116
- set of `uint8` compressed-K chunks, and token positions), loads the FlashMemory DS-V4
117
- checkpoint, runs the forward pass, prints the per-layer and ensembled per-chunk
118
- scores, and demonstrates both **threshold** and **top-K** chunk selection.
119
-
120
- Useful flags:
121
-
122
- | Flag | Default | Meaning |
123
- |------|---------|---------|
124
- | `--device` | `cpu` | `cpu` or `cuda` |
125
- | `--batch` | `2` | number of decode tokens |
126
- | `--n-chunks` | `64` | number of compressed-K chunks |
127
- | `--top-k` | `16` | top-K chunks to select |
128
- | `--threshold` | `0.5` | sigmoid keep threshold |
129
- | `--ensemble` | `max` | cross-layer ensemble mode (`max` / `mean`) |
130
- | `--max-position` | `524288` | RoPE table length (raise to `1048576` for 1M context) |
131
-
132
- Example output (CPU, default args):
133
-
134
- ```
135
- [demo] loaded layers=['l10', 'l12', 'l20'] n_heads=128 head_dim=128 max_position=524288
136
- [demo] per-layer sigmoid score stats (over all chunks):
137
- l10: min=0.4474 mean=0.5021 max=0.6416
138
- ...
139
- [demo] threshold selection (sigmoid > 0.5):
140
- row 0: keep 64/64 chunks (keep ratio 100.0%)
141
- row 1: keep 49/64 chunks (keep ratio 76.6%)
142
- [demo] done. ✅ forward + scoring + selection all ran.
143
- ```
144
-
145
- > The scores above come from **random mock K**, so they cluster near 0.5 — they
146
- > are only meaningful on real CSA keys. The demo's purpose is to verify the
147
- > load → forward → selection path end-to-end.
148
-
149
- ---
150
-
151
- ## Using the model in your own code
152
 
153
  ```python
154
- import torch
155
  from retriever import FlashMemoryRetriever
156
 
157
  model = FlashMemoryRetriever.from_checkpoint(
158
- "weights/flashmemory_ds_v4.safetensors", device="cuda", max_position=524288
159
  )
160
 
161
- hidden = torch.randn(B, 4096, device="cuda") # decode-token hidden states
162
- compressed_k = ... # [B, N, 132] uint8 CSA keys
163
- positions = torch.arange(B, device="cuda") # int64 token positions
164
-
165
- # Per-layer sigmoid scores: {"l10": [B, N], "l12": [B, N], "l20": [B, N]}
166
- per_layer = model(hidden, compressed_k, positions)
167
-
168
- # Cross-layer ensembled per-chunk scores [B, N] ∈ [0, 1]
169
- scores = model.ensemble(hidden, compressed_k, positions, mode="max")
170
 
171
- # Boolean keep-mask [B, N] for the chunks to keep on-device
172
- keep = model.select_topk(hidden, compressed_k, positions, top_k=512) # top-K
173
- keep = model.select_topk(hidden, compressed_k, positions, threshold=0.5) # threshold
174
  ```
175
 
176
- **`compressed_k` format.** Each chunk is `HEAD_DIM + 4 = 132` `uint8` bytes:
177
- the first `128` bytes are the `float8_e4m3` quantized key values, the last `4`
178
- bytes are a single `float32` per-chunk scale. Dequantization is
179
- `fp8_values.view(float8_e4m3).float() * scale`. See `make_mock_compressed_k` in
180
- `demo.py` for how to construct a valid tensor.
181
-
182
- ---
183
 
184
- ## Weights
185
 
186
- **Download:** [Hugging Face](https://huggingface.co/<HF_REPO>) `flashmemory_ds_v4.safetensors` (≈510 MB).
 
187
 
188
- ```bash
189
- huggingface-cli download <HF_REPO> flashmemory_ds_v4.safetensors --local-dir ./weights
190
- python demo.py --ckpt ./weights/flashmemory_ds_v4.safetensors
191
  ```
 
 
 
192
 
193
- `from_checkpoint` accepts either a `.pt` (`torch.save` state-dict) or a
194
- `.safetensors` file. The released `.safetensors` is the **slim** form: it stores
195
- only the four learned tensors per layer
196
- (`wq_a.weight`, `wq_b.weight`, `q_norm_weight`, `weights_proj.weight` for
197
- `l10` / `l12` / `l20`) and **omits the `freqs_cis` RoPE table** (≈400 MB), which
198
- is recomputed at load time from `max_position`. Loading the slim `.safetensors`
199
- is bit-for-bit identical to loading the full `.pt` (verified by output match).
200
-
201
- ---
202
-
203
- ## Files
204
-
205
- | File | Purpose |
206
- |------|---------|
207
- | `retriever.py` | `FlashMemoryRetriever` model + RoPE/Hadamard utils + FP8 dequant (torch-only, self-contained) |
208
- | `demo.py` | minimal runnable demo with mock inputs |
209
- | `toy_flashmemory_inference.py` | toy DeepSeek-V4-FlashMemory sparse-decode loop showing **how the retriever drives memory recall at inference time** (see below) |
210
- | `requirements.txt` | `torch`, `safetensors`, `numpy` |
211
- | `LICENSE` | MIT |
212
-
213
- ---
214
-
215
- ## Toy FlashMemory inference reference (`toy_flashmemory_inference.py`)
216
-
217
- `demo.py` shows a single `hidden → scores` call. `toy_flashmemory_inference.py`
218
- is the **next step up**: a tiny, fully-runnable illustration of *how the Lightning
219
- Indexer Retriever is used inside a DeepSeek-V4-FlashMemory style sparse-decode
220
- loop* to drive "memory recall".
221
-
222
- It is intentionally small and pedagogical. It depends only on `torch` and the
223
- sibling `retriever.py`, and it **reuses the real FlashMemory DS-V4 retriever verbatim** — none
224
- of the scoring math is re-implemented.
225
-
226
- ### The inference flow it demonstrates
227
-
228
- ```
229
- ┌──────────┐ compress & store ┌────────────────────────────┐
230
- │ PREFILL │ historical K/V │ CSA KV-cache (the memory) │
231
- │ (dense │ ──────────────────► │ N compressed chunks, │
232
- │ attn) │ │ each = [132] uint8 fp8-K │
233
- └────┬─────┘ └──────────────┬─────────────┘
234
- │ last hidden state │ scored every 64 steps
235
- ▼ │
236
- ┌──────────────────────── DECODE LOOP ─────────┼──────────────────────────┐
237
- │ for each decode step t: │ │
238
- │ hidden = toy_decoder.step(token, keep_mask) │ (sparse memory attn) │
239
- │ │ │
240
- │ every RETRIEVAL_INTERVAL (= 64) steps: ▼ │
241
- │ scores[N] = retriever.ensemble(hidden, compressed_k, pos) │
242
- │ keep_mask[N] = top-K (or sigmoid > threshold) of scores │
243
- │ → chunks NOT kept are masked to -inf in the next 64 decode steps │
244
- │ of memory attention (== "not recalled onto the GPU") │
245
- └──────────────────────────────────────────────────────────────────────────┘
246
  ```
247
 
248
- 1. **Prefill (dense).** A short prompt is run through dense memory attention. Its
249
- last hidden state seeds the first retrieval cycle (the indexer needs a query
250
- hidden state to score against). In a real run, prefill is also where the
251
- historical KV is compressed into the `[N, 132]` `uint8` CSA chunks.
252
- 2. **Decode loop.** Every step the toy decoder produces a `[B, 4096]` hidden state
253
- and attends over the `N` memory chunks.
254
- 3. **Retrieval cycle (every 64 steps).** The real `FlashMemoryRetriever` scores all
255
- `N` compressed-K chunks against the current decode hidden state, ensembles the
256
- per-layer (`l10`/`l12`/`l20`) sigmoid scores, and selects the chunks to keep —
257
- either **top-K** or **sigmoid > threshold**. This predicts which chunks the
258
- *next ~64 tokens* will attend to.
259
- 4. **Sparse attention.** For the next 64 steps, chunks **not** selected have their
260
- memory-attention logits set to `-inf`, so they contribute nothing.
261
-
262
- ### What the masking simulates (important)
263
-
264
- * This toy does **not** perform any real CPU↔GPU KV-cache transfer. The swap-in /
265
- swap-out machinery is part of the internal FlashMemory engineering and is **not**
266
- included in this release.
267
- * We **simulate memory recall by masking the FlashMemory Retriever's per-chunk
268
- decisions**: a chunk the retriever did not select gets its attention logit set
269
- to `-inf`. This is equivalent to *"that chunk's KV was never recalled onto the
270
- GPU, so it cannot be attended to"* — for the attention output, masking a chunk
271
- out and never loading it produce the same result.
272
- * The toy's purpose is to make the **decode-time control flow** concrete: where the
273
- retriever fires, what it consumes (decode hidden state + compressed CSA keys),
274
- what it produces (a keep/drop mask), and how that mask sparsifies the next
275
- window of decode steps.
276
-
277
- ### What it is / is NOT
278
 
279
- * **IS:** a minimal, torch-only illustration of the decode-time control flow that
280
- drives memory recall with the real FlashMemory DS-V4 retriever.
281
- * **IS NOT:** a runnable DeepSeek-V4. The "decoder" is a couple of layers of
282
- randomly-initialized toy attention/MLP whose only jobs are (a) to emit a
283
- `[B, 4096]` hidden state for the retriever and (b) to own a memory attention we
284
- can sparsify. The generated tokens are meaningless.
285
-
286
- > **The production version cannot be released.** It depends on the internal sglang
287
- > + DeepSeek-V4 CSA framework (native FP8 indexer, real compressed KV-cache,
288
- > attention-sink, threshold fallback, per-request routing, and the actual KV swap
289
- > engine). This file shows the *algorithmic role* of the retriever only.
290
-
291
- ### Run
292
 
293
  ```bash
294
  python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
295
  ```
296
 
297
- Runs on CPU by default; pass `--device cuda` for GPU.
298
-
299
- | Flag | Default | Meaning |
300
- |------|---------|---------|
301
- | `--n-chunks` | `256` | number of CSA memory chunks (the long history) |
302
- | `--steps` | `192` | decode steps to generate |
303
- | `--retrieval-interval` | `64` | run the retriever every N steps (FlashMemory default) |
304
- | `--select-mode` | `topk` | `topk` or `threshold` |
305
- | `--top-k` | `64` | chunks to recall per cycle (`select-mode=topk`) |
306
- | `--threshold` | `0.5` | sigmoid keep threshold (`select-mode=threshold`) |
307
- | `--ensemble` | `max` | cross-layer ensemble mode (`max` / `mean`) |
308
- | `--batch` | `1` | parallel decode sequences |
309
-
310
- Example output (CPU, default args — `top-K=64` out of `256` chunks):
311
-
312
- ```
313
- FlashMemory DS-V4 — toy sparse-decode loop
314
- [load] weights/flashmemory_ds_v4.safetensors
315
- [load] layers=['l10', 'l12', 'l20'] n_heads=128 head_dim=128
316
- [init] decoder: 2 layers, 8 heads | CSA memory: 256 chunks [132] uint8
317
-
318
- [decode] 192 steps, retriever every 64 steps (topk [top-K=64], ensemble=max)
319
- ------------------------------------------------------------
320
- [cycle 0] pos 8..71 | keep 25.0% (64/256) | score mean=0.4910 max=0.5445
321
- [cycle 1] pos 72..135 | keep 25.0% (64/256) | score mean=0.4910 max=0.5445
322
- ...
323
- ------------------------------------------------------------
324
- [done] 192 tokens, 3 cycles, avg keep/cycle: 25.0% → ~75% CSA KV dropped
325
- [note] Dropped chunks are masked to -inf in attention (= KV not recalled to GPU).
326
- ```
327
-
328
- > As in `demo.py`, the scores come from **random mock K** and cluster near 0.5;
329
- > they are only meaningful on real CSA keys. The toy's value is the *control flow*
330
- > — watch each retrieval cycle report how many chunks were scored, recalled, and
331
- > masked out.
332
 
333
- ---
334
 
335
 
336
  ## License
337
 
338
- MIT — see [`LICENSE`](./LICENSE).
 
6
  - retrieval
7
  - kv-cache
8
  - sparse-attention
 
9
  - long-context
10
  - flashmemory
11
  datasets:
 
17
 
18
  # FlashMemory DS-V4 Retriever
19
 
20
+ A lightweight retriever that sparsifies the **DeepSeek-V4 CSA KV-cache**. Given a
21
+ decode-token hidden state, it predicts which compressed-K chunks the next
22
+ ~64 tokens will attend to — keeping only those on GPU, offloading the rest.
23
 
24
+ In downstream evaluation it matches or beats the full-attention baseline on
25
+ reasoning-heavy long-context tasks (**RULER, LongMemEval, LongBench V2**)
26
+ while reducing KV-cache usage by **~85–90%**. Precise needle-retrieval tasks
27
+ require an additional threshold-fallback mechanism (not in this release).
 
28
 
29
+ ## Quick start
30
 
31
  ```bash
32
+ pip install torch safetensors
33
  python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
34
  ```
35
 
36
+ ## Usage
37
 
38
  ```python
 
39
  from retriever import FlashMemoryRetriever
40
 
41
  model = FlashMemoryRetriever.from_checkpoint(
42
+ "weights/flashmemory_ds_v4.safetensors", device="cuda"
43
  )
44
 
45
+ # hidden: [B, 4096] decode hidden state
46
+ # compressed_k: [B, N, 132] uint8 CSA keys
47
+ # positions: [B] int64 token positions
48
 
49
+ scores = model.ensemble(hidden, compressed_k, positions, mode="max") # [B, N]
50
+ keep = model.select_topk(hidden, compressed_k, positions, top_k=512) # boolean mask
 
51
  ```
52
 
53
+ **`compressed_k` format:** each chunk = 128 bytes `float8_e4m3` values + 4 bytes `float32` scale. See `make_mock_compressed_k()` in `demo.py`.
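
As a rough illustration of the byte layout described above, the sketch below decodes one 132-byte chunk in plain Python. It assumes the FP8 bytes use the `e4m3fn` encoding (the variant PyTorch exposes as `torch.float8_e4m3fn`) and that the trailing `float32` scale is stored little-endian; neither detail is stated in this README. `decode_e4m3fn` and `dequantize_chunk` are hypothetical helper names, not functions from `retriever.py`.

```python
import struct

HEAD_DIM = 128  # FP8 key bytes per chunk; the last 4 bytes hold the float32 scale


def decode_e4m3fn(byte: int) -> float:
    """Decode one e4m3fn byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0x0F
    mant = byte & 0x07
    if exp == 0x0F and mant == 0x07:
        return float("nan")  # e4m3fn has no infinities, only this NaN pattern
    if exp == 0:
        return sign * (mant / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


def dequantize_chunk(chunk: bytes) -> list:
    """Return HEAD_DIM floats: each FP8 value times the per-chunk scale."""
    if len(chunk) != HEAD_DIM + 4:
        raise ValueError(f"expected {HEAD_DIM + 4} bytes, got {len(chunk)}")
    # Assumption: the scale is a little-endian float32.
    (scale,) = struct.unpack("<f", chunk[HEAD_DIM:])
    return [decode_e4m3fn(b) * scale for b in chunk[:HEAD_DIM]]


# One mock chunk: every key byte is 0x38 (1.0 in e4m3fn), scale = 0.5.
chunk = bytes([0x38] * HEAD_DIM) + struct.pack("<f", 0.5)
values = dequantize_chunk(chunk)
print(len(values), values[0])
```

For real tensors, `make_mock_compressed_k()` in `demo.py` remains the reference for building valid inputs.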
54
 
55
+ ## Architecture
56
 
57
+ 3-layer joint model (`l10`, `l12`, `l20`), 128 heads, 2048 LoRA rank. Per-layer
58
+ sigmoid scores are ensembled (`max` or `mean`) per chunk.
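
The cross-layer ensembling and top-K selection mentioned above can be sketched in a few lines of plain Python. This is only an illustration on made-up scores for a single decode token; `ensemble_scores` and `topk_mask` are hypothetical names and are not the methods of `FlashMemoryRetriever`.

```python
def ensemble_scores(per_layer: dict, mode: str = "max") -> list:
    """Combine per-layer chunk scores (each a list of length N) into one list."""
    columns = list(zip(*per_layer.values()))  # one tuple of layer scores per chunk
    if mode == "max":
        return [max(col) for col in columns]
    if mode == "mean":
        return [sum(col) / len(col) for col in columns]
    raise ValueError(f"unknown ensemble mode: {mode}")


def topk_mask(scores: list, top_k: int) -> list:
    """Boolean keep-mask marking the top_k highest-scoring chunks."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept = set(order[:top_k])
    return [i in kept for i in range(len(scores))]


# Made-up sigmoid scores for three chunks, one list per CSA layer.
per_layer = {
    "l10": [0.2, 0.9, 0.4],
    "l12": [0.6, 0.1, 0.5],
    "l20": [0.3, 0.3, 0.45],
}
max_scores = ensemble_scores(per_layer, mode="max")
mean_scores = ensemble_scores(per_layer, mode="mean")
keep = topk_mask(max_scores, top_k=2)
print(max_scores, keep)
```

With `max`, a chunk is kept if any one layer scores it highly; `mean` requires broader agreement across layers.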
59
 
60
  ```
61
+ hidden [B,4096] → q-proj → RoPE(YaRN) → Hadamard → q [B,128,128]
62
+ → weights_proj → fused_w [B,128]
63
+ compressed_k → FP8 dequant → k [B,N,128]
64
 
65
+ score = sigmoid( Σ( relu(k @ qᵀ) · fused_w ) ) ∈ [0,1]
66
  ```
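
To make the scoring formula concrete, here is a minimal pure-Python sketch for one decode token with tiny dimensions (2 heads, 2 dims instead of 128 each). It assumes `q`, `k`, and `fused_w` have already been produced by the projection, RoPE, Hadamard, and dequantization steps; `chunk_scores` is a hypothetical helper, not part of `retriever.py`.

```python
import math


def chunk_scores(q: list, k: list, fused_w: list) -> list:
    """score[n] = sigmoid( sum_h relu(k[n] . q[h]) * fused_w[h] ).

    q: [H][D] per-head queries, k: [N][D] dequantized chunk keys,
    fused_w: [H] per-head weights.
    """
    scores = []
    for key in k:
        logit = 0.0
        for head_q, w in zip(q, fused_w):
            dot = sum(a * b for a, b in zip(key, head_q))
            logit += max(dot, 0.0) * w  # ReLU per head, then weighted sum
        scores.append(1.0 / (1.0 + math.exp(-logit)))
    return scores


q = [[1.0, 0.0], [0.0, 1.0]]    # two heads
k = [[2.0, -1.0], [0.0, 0.0]]   # two chunks
fused_w = [0.5, 0.5]
scores = chunk_scores(q, k, fused_w)
print(scores)
```

The first chunk reaches a logit of 1.0 (only head 0 survives the ReLU), while the all-zero chunk lands at exactly 0.5.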
67
 
68
+ ## Toy inference reference
69
 
70
+ `toy_flashmemory_inference.py` illustrates how the retriever drives memory
71
+ recall during decode: every 64 steps it re-scores all chunks, and unselected
72
+ ones are masked from attention (equivalent to "not recalled to GPU").
73
 
74
  ```bash
75
  python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
76
  ```
77
 
78
+ > The decoder is a few toy layers with random weights — it is **not** a real
79
+ > DeepSeek-V4. The retriever, scoring math, and decode-time control flow are real.
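
The decode-time control flow described above can be outlined as follows. This is a simplified sketch, not the contents of `toy_flashmemory_inference.py`: random numbers stand in for both the retriever scores and the attention logits, and `toy_decode`, `select_topk`, and `masked_softmax` are hypothetical names chosen for illustration.

```python
import math
import random

RETRIEVAL_INTERVAL = 64  # re-score all chunks every 64 decode steps


def select_topk(scores, top_k):
    """Boolean keep-mask for the top_k highest-scoring chunks."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept = set(order[:top_k])
    return [i in kept for i in range(len(scores))]


def masked_softmax(logits, keep_mask):
    """Softmax over chunks; dropped chunks get logit -inf, hence weight 0."""
    masked = [x if keep else float("-inf") for x, keep in zip(logits, keep_mask)]
    peak = max(masked)
    exps = [math.exp(x - peak) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


def toy_decode(n_steps, n_chunks, top_k, seed=0):
    rng = random.Random(seed)
    keep_mask = [True] * n_chunks
    cycles, weights = [], []
    for step in range(n_steps):
        if step % RETRIEVAL_INTERVAL == 0:
            # Stand-in for the real retriever call (model.ensemble(...)).
            scores = [rng.random() for _ in range(n_chunks)]
            keep_mask = select_topk(scores, top_k)
            cycles.append((step, sum(keep_mask)))
        # Stand-in for the decoder's memory-attention logits at this step.
        logits = [rng.gauss(0.0, 1.0) for _ in range(n_chunks)]
        weights = masked_softmax(logits, keep_mask)
    return cycles, keep_mask, weights


cycles, keep_mask, weights = toy_decode(n_steps=192, n_chunks=256, top_k=64)
print(cycles)  # [(0, 64), (64, 64), (128, 64)]
```

Masking a chunk to `-inf` gives it exactly zero attention weight, which is why the mask can stand in for a chunk that was never recalled to the GPU.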
80
 
81
+ ## Files
82
 
83
+ | File | Purpose |
84
+ |------|---------|
85
+ | `retriever.py` | `FlashMemoryRetriever` model (torch-only, self-contained) |
86
+ | `demo.py` | minimal demo with mock inputs |
87
+ | `toy_flashmemory_inference.py` | toy sparse-decode loop |
88
+ | `weights/flashmemory_ds_v4.safetensors` | trained weights (~510 MB) |
89
 
90
  ## License
91
 
92
+ MIT