libertywing Claude Code commited on
Commit
640b654
Β·
0 Parent(s):

Initial release: FlashMemory DS-V4 Retriever

Browse files

- FlashMemoryRetriever model (retriever.py)
- Minimal demo with mock inputs (demo.py)
- Toy sparse-decode inference reference (toy_flashmemory_inference.py)
- Model weights (flashmemory_ds_v4.safetensors, ~510 MB)

Co-Authored-By: Claude Code <noreply@anthropic.com>

.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ weights/*.safetensors filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 FlashMemory Authors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FlashMemory DS-V4 Retriever
2
+
3
+ A standalone, dependency-light reference implementation of the **FlashMemory DS-V4
4
+ Retriever** β€” a lightweight retriever that sparsifies the **DeepSeek-V4
5
+ Compressed-Sparse-Attention (CSA)** KV cache.
6
+
7
+ Given the hidden state of a decode token, the retriever predicts which CSA
8
+ KV-cache chunks (compressed keys) the upcoming tokens will attend to, so that
9
+ only the **top-scoring chunks** need to stay resident on the GPU and the rest can
10
+ be offloaded to CPU / disk. This recovers most of the quality of full attention
11
+ on long-context tasks while keeping a small fraction of the KV cache on-device.
12
+
13
+ This release contains the **algorithm + weights + a minimal, runnable PyTorch
14
+ demo**. It depends only on `torch` (plus `numpy` / `safetensors` for convenience).
15
+
16
+ > **Scope note.** The full sglang serving integration β€” KV-cache swap-in/out,
17
+ > attention-sink, threshold fallback, per-request retriever routing β€” is **not**
18
+ > included here, because it is tightly coupled to the internal DeepSeek-V4 CSA
19
+ > framework and cannot run outside it. This repository provides the retriever
20
+ > **algorithm reference implementation and trained weights only.**
21
+
22
+ ---
23
+
24
+ ## Model architecture
25
+
26
+ The retriever scores each compressed-K chunk against the decode token's hidden
27
+ state. For a single CSA layer:
28
+
29
+ ```
30
+ hidden [B, 4096]
31
+ β†’ wq_a (4096 β†’ Q_LORA_RANK)
32
+ β†’ RMSNorm (q_norm_weight, eps=1e-6)
33
+ β†’ wq_b (Q_LORA_RANK β†’ N_HEADS * HEAD_DIM)
34
+ β†’ reshape [B, N_HEADS, HEAD_DIM]
35
+ β†’ RoPE (YaRN, applied to the last ROPE_DIM=64 dims, base=160000)
36
+ β†’ Hadamard (normalized Walsh-Hadamard transform)
37
+ β†’ q [B, N_HEADS, HEAD_DIM]
38
+
39
+ hidden [B, 4096]
40
+ β†’ weights_proj (4096 β†’ N_HEADS)
41
+ β†’ Γ— weight_scale (= HEAD_DIM^-0.5 * N_HEADS^-0.5)
42
+ β†’ fused_w [B, N_HEADS]
43
+
44
+ compressed_k [B, N, HEAD_DIM + 4] (uint8)
45
+ β†’ bytes[:HEAD_DIM] viewed as float8_e4m3 β†’ dequantize
46
+ β†’ bytes[HEAD_DIM:] viewed as float32 β†’ per-chunk scale
47
+ β†’ k [B, N, HEAD_DIM]
48
+
49
+ score_per_head = relu( einsum('bnd,bhd->bnh', k, q) ) # [B, N, N_HEADS]
50
+ logit = (score_per_head * fused_w[:, None, :]).sum(-1) # [B, N]
51
+ score = sigmoid(logit) ∈ [0, 1] # [B, N]
52
+ ```
53
+
54
+ **Hyperparameters (FlashMemory DS-V4):** `Q_LORA_RANK = 2048`, `N_HEADS = 128`,
55
+ `HEAD_DIM = 128`, `ROPE_DIM = 64`, `ROPE_BASE = 160000`, `ROPE_FACTOR = 16`,
56
+ `ROPE_ORIGINAL_SEQ_LEN = 65536`, `ROPE_BETA_FAST = 32`, `ROPE_BETA_SLOW = 1`,
57
+ `RMS_NORM_EPS = 1e-6`.
58
+
59
+ ### Joint multi-layer checkpoint + ensemble
60
+
61
+ FlashMemory DS-V4 is a **joint checkpoint** holding three independent CSA layers
62
+ (`l10`, `l12`, `l20`), each with its own weights. At inference time the per-layer
63
+ sigmoid scores are **ensembled per chunk** β€” cross-layer `max` (default) or
64
+ `mean` β€” to produce a single keep/drop decision per chunk.
65
+
66
+ ---
67
+
68
+ ## What is FlashMemory DS-V4?
69
+
70
+ FlashMemory DS-V4 is part of the latest retraining generation of these retrievers. In the
71
+ project's downstream evaluation it stays close to the full-attention baseline on
72
+ long-context tasks (e.g. RULER, LongMemEval, LongBench V2) while keeping only a
73
+ small fraction of the CSA KV cache on-device (β‰ˆ90% KV reduction in the deployment
74
+ sweet spot for reasoning-heavy long-context tasks). Precise-needle retrieval
75
+ tasks need an extra threshold-fallback mechanism in the serving layer (not part
76
+ of this standalone release).
77
+
78
+ ---
79
+
80
+ ## Installation
81
+
82
+ ```bash
83
+ pip install -r requirements.txt
84
+ ```
85
+
86
+ Only `torch` is strictly required to run the model and demo. `float8_e4m3`
87
+ tensor support requires a reasonably recent PyTorch (β‰₯ 2.1).
88
+
89
+ ---
90
+
91
+ ## Running the demo
92
+
93
+ ```bash
94
+ python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
95
+ ```
96
+
97
+ The demo builds **random mock inputs** (a batch of decode-token hidden states, a
98
+ set of `uint8` compressed-K chunks, and token positions), loads the FlashMemory DS-V4
99
+ checkpoint, runs the forward pass, prints the per-layer and ensembled per-chunk
100
+ scores, and demonstrates both **threshold** and **top-K** chunk selection.
101
+
102
+ Useful flags:
103
+
104
+ | Flag | Default | Meaning |
105
+ |------|---------|---------|
106
+ | `--device` | `cpu` | `cpu` or `cuda` |
107
+ | `--batch` | `2` | number of decode tokens |
108
+ | `--n-chunks` | `64` | number of compressed-K chunks |
109
+ | `--top-k` | `16` | top-K chunks to select |
110
+ | `--threshold` | `0.5` | sigmoid keep threshold |
111
+ | `--ensemble` | `max` | cross-layer ensemble mode (`max` / `mean`) |
112
+ | `--max-position` | `524288` | RoPE table length (raise to `1048576` for 1M context) |
113
+
114
+ Example output (CPU, default args):
115
+
116
+ ```
117
+ [demo] loaded layers=['l10', 'l12', 'l20'] n_heads=128 head_dim=128 max_position=524288
118
+ [demo] per-layer sigmoid score stats (over all chunks):
119
+ l10: min=0.4474 mean=0.5021 max=0.6416
120
+ ...
121
+ [demo] threshold selection (sigmoid > 0.5):
122
+ row 0: keep 64/64 chunks (keep ratio 100.0%)
123
+ row 1: keep 49/64 chunks (keep ratio 76.6%)
124
+ [demo] done. βœ… forward + scoring + selection all ran.
125
+ ```
126
+
127
+ > The scores above come from **random mock K**, so they cluster near 0.5 β€” they
128
+ > are only meaningful on real CSA keys. The demo's purpose is to verify the
129
+ > load β†’ forward β†’ selection path end-to-end.
130
+
131
+ ---
132
+
133
+ ## Using the model in your own code
134
+
135
+ ```python
136
+ import torch
137
+ from retriever import FlashMemoryRetriever
138
+
139
+ model = FlashMemoryRetriever.from_checkpoint(
140
+ "weights/flashmemory_ds_v4.safetensors", device="cuda", max_position=524288
141
+ )
142
+
143
+ hidden = torch.randn(B, 4096, device="cuda") # decode-token hidden states
144
+ compressed_k = ... # [B, N, 132] uint8 CSA keys
145
+ positions = torch.arange(B, device="cuda") # int64 token positions
146
+
147
+ # Per-layer sigmoid scores: {"l10": [B, N], "l12": [B, N], "l20": [B, N]}
148
+ per_layer = model(hidden, compressed_k, positions)
149
+
150
+ # Cross-layer ensembled per-chunk scores [B, N] ∈ [0, 1]
151
+ scores = model.ensemble(hidden, compressed_k, positions, mode="max")
152
+
153
+ # Boolean keep-mask [B, N] for the chunks to keep on-device
154
+ keep = model.select_topk(hidden, compressed_k, positions, top_k=512) # top-K
155
+ keep = model.select_topk(hidden, compressed_k, positions, threshold=0.5) # threshold
156
+ ```
157
+
158
+ **`compressed_k` format.** Each chunk is `HEAD_DIM + 4 = 132` `uint8` bytes:
159
+ the first `128` bytes are the `float8_e4m3` quantized key values, the last `4`
160
+ bytes are a single `float32` per-chunk scale. Dequantization is
161
+ `fp8_values.view(float8_e4m3).float() * scale`. See `make_mock_compressed_k` in
162
+ `demo.py` for how to construct a valid tensor.
163
+
164
+ ---
165
+
166
+ ## Weights
167
+
168
+ **Download:** [Hugging Face](https://huggingface.co/<HF_REPO>) β€” `flashmemory_ds_v4.safetensors` (β‰ˆ510 MB).
169
+
170
+ ```bash
171
+ huggingface-cli download <HF_REPO> flashmemory_ds_v4.safetensors --local-dir ./weights
172
+ python demo.py --ckpt ./weights/flashmemory_ds_v4.safetensors
173
+ ```
174
+
175
+ `from_checkpoint` accepts either a `.pt` (`torch.save` state-dict) or a
176
+ `.safetensors` file. The released `.safetensors` is the **slim** form: it stores
177
+ only the four learned tensors per layer
178
+ (`wq_a.weight`, `wq_b.weight`, `q_norm_weight`, `weights_proj.weight` for
179
+ `l10` / `l12` / `l20`) and **omits the `freqs_cis` RoPE table** (β‰ˆ400 MB), which
180
+ is recomputed at load time from `max_position`. Loading the slim `.safetensors`
181
+ is bit-for-bit identical to loading the full `.pt` (verified by output match).
182
+
183
+ ---
184
+
185
+ ## Files
186
+
187
+ | File | Purpose |
188
+ |------|---------|
189
+ | `retriever.py` | `FlashMemoryRetriever` model + RoPE/Hadamard utils + FP8 dequant (torch-only, self-contained) |
190
+ | `demo.py` | minimal runnable demo with mock inputs |
191
+ | `toy_flashmemory_inference.py` | toy DeepSeek-V4-FlashMemory sparse-decode loop showing **how the retriever drives memory recall at inference time** (see below) |
192
+ | `requirements.txt` | `torch`, `safetensors`, `numpy` |
193
+ | `LICENSE` | MIT |
194
+
195
+ ---
196
+
197
+ ## Toy FlashMemory inference reference (`toy_flashmemory_inference.py`)
198
+
199
+ `demo.py` shows a single `hidden β†’ scores` call. `toy_flashmemory_inference.py`
200
+ is the **next step up**: a tiny, fully-runnable illustration of *how the Lightning
201
+ Indexer Retriever is used inside a DeepSeek-V4-FlashMemory style sparse-decode
202
+ loop* to drive "memory recall".
203
+
204
+ It is intentionally small and pedagogical. It depends only on `torch` and the
205
+ sibling `retriever.py`, and it **reuses the real FlashMemory DS-V4 retriever verbatim** β€” none
206
+ of the scoring math is re-implemented.
207
+
208
+ ### The inference flow it demonstrates
209
+
210
+ ```
211
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” compress & store β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
212
+ β”‚ PREFILL β”‚ historical K/V β”‚ CSA KV-cache (the memory) β”‚
213
+ β”‚ (dense β”‚ ──────────────────► β”‚ N compressed chunks, β”‚
214
+ β”‚ attn) β”‚ β”‚ each = [132] uint8 fp8-K β”‚
215
+ β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
216
+ β”‚ last hidden state β”‚ scored every 64 steps
217
+ β–Ό β”‚
218
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ DECODE LOOP ─────────┼──────────────────────────┐
219
+ β”‚ for each decode step t: β”‚ β”‚
220
+ β”‚ hidden = toy_decoder.step(token, keep_mask) β”‚ (sparse memory attn) β”‚
221
+ β”‚ β”‚ β”‚
222
+ β”‚ every RETRIEVAL_INTERVAL (= 64) steps: β–Ό β”‚
223
+ β”‚ scores[N] = retriever.ensemble(hidden, compressed_k, pos) β”‚
224
+ β”‚ keep_mask[N] = top-K (or sigmoid > threshold) of scores β”‚
225
+ β”‚ β†’ chunks NOT kept are masked to -inf in the next 64 decode steps β”‚
226
+ β”‚ of memory attention (== "not recalled onto the GPU") β”‚
227
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
228
+ ```
229
+
230
+ 1. **Prefill (dense).** A short prompt is run through dense memory attention. Its
231
+ last hidden state seeds the first retrieval cycle (the indexer needs a query
232
+ hidden state to score against). In a real run, prefill is also where the
233
+ historical KV is compressed into the `[N, 132]` `uint8` CSA chunks.
234
+ 2. **Decode loop.** Every step the toy decoder produces a `[B, 4096]` hidden state
235
+ and attends over the `N` memory chunks.
236
+ 3. **Retrieval cycle (every 64 steps).** The real `FlashMemoryRetriever` scores all
237
+ `N` compressed-K chunks against the current decode hidden state, ensembles the
238
+ per-layer (`l10`/`l12`/`l20`) sigmoid scores, and selects the chunks to keep β€”
239
+ either **top-K** or **sigmoid > threshold**. This predicts which chunks the
240
+ *next ~64 tokens* will attend to.
241
+ 4. **Sparse attention.** For the next 64 steps, chunks **not** selected have their
242
+ memory-attention logits set to `-inf`, so they contribute nothing.
243
+
244
+ ### What the masking simulates (important)
245
+
246
+ * This toy does **not** perform any real CPU↔GPU KV-cache transfer. The swap-in /
247
+ swap-out machinery is part of the internal FlashMemory engineering and is **not**
248
+ included in this release.
249
+ * We **simulate memory recall by masking the FlashMemory Retriever's per-chunk
250
+ decisions**: a chunk the retriever did not select gets its attention logit set
251
+ to `-inf`. This is equivalent to *"that chunk's KV was never recalled onto the
252
+ GPU, so it cannot be attended to"* β€” for the attention output, masking a chunk
253
+ out and never loading it produce the same result.
254
+ * The toy's purpose is to make the **decode-time control flow** concrete: where the
255
+ retriever fires, what it consumes (decode hidden state + compressed CSA keys),
256
+ what it produces (a keep/drop mask), and how that mask sparsifies the next
257
+ window of decode steps.
258
+
259
+ ### What it is / is NOT
260
+
261
+ * **IS:** a minimal, torch-only illustration of the decode-time control flow that
262
+ drives memory recall with the real FlashMemory DS-V4 retriever.
263
+ * **IS NOT:** a runnable DeepSeek-V4. The "decoder" is a couple of layers of
264
+ randomly-initialized toy attention/MLP whose only jobs are (a) to emit a
265
+ `[B, 4096]` hidden state for the retriever and (b) to own a memory attention we
266
+ can sparsify. The generated tokens are meaningless.
267
+
268
+ > **The production version cannot be released.** It depends on the internal sglang
269
+ > + DeepSeek-V4 CSA framework (native FP8 indexer, real compressed KV-cache,
270
+ > attention-sink, threshold fallback, per-request routing, and the actual KV swap
271
+ > engine). This file shows the *algorithmic role* of the retriever only.
272
+
273
+ ### Run
274
+
275
+ ```bash
276
+ python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
277
+ ```
278
+
279
+ Runs on CPU by default; pass `--device cuda` for GPU.
280
+
281
+ | Flag | Default | Meaning |
282
+ |------|---------|---------|
283
+ | `--n-chunks` | `256` | number of CSA memory chunks (the long history) |
284
+ | `--steps` | `192` | decode steps to generate |
285
+ | `--retrieval-interval` | `64` | run the retriever every N steps (FlashMemory default) |
286
+ | `--select-mode` | `topk` | `topk` or `threshold` |
287
+ | `--top-k` | `64` | chunks to recall per cycle (`select-mode=topk`) |
288
+ | `--threshold` | `0.5` | sigmoid keep threshold (`select-mode=threshold`) |
289
+ | `--ensemble` | `max` | cross-layer ensemble mode (`max` / `mean`) |
290
+ | `--batch` | `1` | parallel decode sequences |
291
+
292
+ Example output (CPU, default args β€” `top-K=64` out of `256` chunks):
293
+
294
+ ```
295
+ FlashMemory DS-V4 β€” toy sparse-decode loop
296
+ [load] weights/flashmemory_ds_v4.safetensors
297
+ [load] layers=['l10', 'l12', 'l20'] n_heads=128 head_dim=128
298
+ [init] decoder: 2 layers, 8 heads | CSA memory: 256 chunks [132] uint8
299
+
300
+ [decode] 192 steps, retriever every 64 steps (topk [top-K=64], ensemble=max)
301
+ ------------------------------------------------------------
302
+ [cycle 0] pos 8..71 | keep 25.0% (64/256) | score mean=0.4910 max=0.5445
303
+ [cycle 1] pos 72..135 | keep 25.0% (64/256) | score mean=0.4910 max=0.5445
304
+ ...
305
+ ------------------------------------------------------------
306
+ [done] 192 tokens, 3 cycles, avg keep/cycle: 25.0% β†’ ~75% CSA KV dropped
307
+ [note] Dropped chunks are masked to -inf in attention (= KV not recalled to GPU).
308
+ ```
309
+
310
+ > As in `demo.py`, the scores come from **random mock K** and cluster near 0.5;
311
+ > they are only meaningful on real CSA keys. The toy's value is the *control flow*
312
+ > β€” watch each retrieval cycle report how many chunks were scored, recalled, and
313
+ > masked out.
314
+
315
+ ---
316
+
317
+
318
+ ## License
319
+
320
+ MIT β€” see [`LICENSE`](./LICENSE).
demo.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ demo.py β€” minimal standalone demo for the FlashMemory DS-V4 Retriever
3
+ =====================================================================
4
+
5
+ Builds random mock inputs, loads the FlashMemory DS-V4 joint checkpoint, runs
6
+ a forward pass, and prints per-chunk scores plus a top-K selection summary.
7
+
8
+ Run::
9
+
10
+ python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
11
+
12
+ Runs on CPU by default; pass ``--device cuda`` to use a GPU.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+
19
+ import torch
20
+
21
+ from retriever import FlashMemoryRetriever, dequant_compressed_k
22
+
23
+
24
+ def make_mock_compressed_k(
25
+ batch: int,
26
+ n_chunks: int,
27
+ head_dim: int = 128,
28
+ device: str = "cpu",
29
+ seed: int = 0,
30
+ ) -> torch.Tensor:
31
+ """Construct a valid mock ``compressed_k`` tensor [B, N, head_dim + 4] uint8.
32
+
33
+ Layout per chunk: ``head_dim`` float8_e4m3 bytes followed by one float32 scale
34
+ (4 bytes). We build it the same way the real CSA cache stores it:
35
+ 1. sample random key vectors, cast to float8_e4m3, view as uint8;
36
+ 2. sample a small positive per-chunk scale, view its float32 as 4 uint8 bytes;
37
+ 3. concatenate along the last dim.
38
+ """
39
+ g = torch.Generator(device=device).manual_seed(seed)
40
+
41
+ # 1) fp8 key bytes
42
+ k_vals = torch.randn(batch, n_chunks, head_dim, generator=g, device=device) * 0.5
43
+ k_fp8 = k_vals.to(torch.float8_e4m3fn)
44
+ fp8_bytes = k_fp8.view(torch.uint8) # [B, N, head_dim]
45
+
46
+ # 2) float32 per-chunk scale β†’ 4 bytes
47
+ scale = (0.05 + 0.15 * torch.rand(batch, n_chunks, 1, generator=g, device=device)).float()
48
+ scale_bytes = scale.view(torch.uint8) # [B, N, 4]
49
+
50
+ compressed = torch.cat([fp8_bytes, scale_bytes], dim=-1) # [B, N, head_dim + 4]
51
+ assert compressed.shape[-1] == head_dim + 4
52
+ return compressed.contiguous()
53
+
54
+
55
+ def main():
56
+ ap = argparse.ArgumentParser(description="FlashMemory DS-V4 Retriever demo")
57
+ ap.add_argument("--ckpt", required=True, help="path to joint checkpoint (.pt)")
58
+ ap.add_argument("--device", default="cpu", help="cpu or cuda (default: cpu)")
59
+ ap.add_argument("--batch", type=int, default=2, help="number of decode tokens")
60
+ ap.add_argument("--n-chunks", type=int, default=64, help="number of compressed-K chunks")
61
+ ap.add_argument("--max-position", type=int, default=524288,
62
+ help="RoPE table length (raise to 1048576 for 1M context)")
63
+ ap.add_argument("--top-k", type=int, default=16, help="top-K chunks to select")
64
+ ap.add_argument("--threshold", type=float, default=0.5, help="sigmoid keep threshold")
65
+ ap.add_argument("--ensemble", default="max", choices=["max", "mean"],
66
+ help="cross-layer ensemble mode")
67
+ ap.add_argument("--seed", type=int, default=0)
68
+ args = ap.parse_args()
69
+
70
+ torch.manual_seed(args.seed)
71
+ device = args.device
72
+
73
+ print(f"[demo] loading checkpoint: {args.ckpt}")
74
+ model = FlashMemoryRetriever.from_checkpoint(
75
+ args.ckpt, device=device, max_position=args.max_position
76
+ )
77
+ model.eval()
78
+ print(f"[demo] loaded layers={model.layer_names} n_heads={model.n_heads} "
79
+ f"head_dim={model.head_dim} max_position={model.max_position}")
80
+
81
+ # ── Mock inputs ─────────────────────────────────────────────────────────
82
+ B, N = args.batch, args.n_chunks
83
+ hidden = torch.randn(B, 4096, device=device, dtype=torch.float32)
84
+ compressed_k = make_mock_compressed_k(B, N, head_dim=model.head_dim,
85
+ device=device, seed=args.seed)
86
+ # token positions for each decode token (arbitrary; here spaced out)
87
+ positions = torch.arange(B, device=device, dtype=torch.int64) * 1000 + 4096
88
+
89
+ print(f"\n[demo] mock inputs: hidden={tuple(hidden.shape)} "
90
+ f"compressed_k={tuple(compressed_k.shape)} ({compressed_k.dtype}) "
91
+ f"positions={positions.tolist()}")
92
+
93
+ # sanity: show dequant works
94
+ k_float = dequant_compressed_k(compressed_k, head_dim=model.head_dim)
95
+ print(f"[demo] dequantized K: shape={tuple(k_float.shape)} "
96
+ f"mean={k_float.mean().item():+.4f} std={k_float.std().item():.4f}")
97
+
98
+ # ── Per-layer scores ──────────────────────────────────────────────────────
99
+ per_layer = model(hidden, compressed_k, positions, apply_sigmoid=True)
100
+ print("\n[demo] per-layer sigmoid score stats (over all chunks):")
101
+ for name, s in per_layer.items():
102
+ print(f" {name}: min={s.min().item():.4f} mean={s.mean().item():.4f} "
103
+ f"max={s.max().item():.4f}")
104
+
105
+ # ── Cross-layer ensemble ──────────────────────────────────────────────────
106
+ scores = model.ensemble(hidden, compressed_k, positions, mode=args.ensemble) # [B, N]
107
+ print(f"\n[demo] ensembled ({args.ensemble}) per-chunk scores [B={B}, N={N}]:")
108
+ for b in range(B):
109
+ row = scores[b]
110
+ preview = ", ".join(f"{v:.3f}" for v in row[:12].tolist())
111
+ print(f" row {b}: [{preview}{', ...' if N > 12 else ''}]")
112
+
113
+ # ── Selection: threshold ──────────────────────────────────────────────────
114
+ keep_thr = model.select_topk(hidden, compressed_k, positions,
115
+ threshold=args.threshold, mode=args.ensemble)
116
+ print(f"\n[demo] threshold selection (sigmoid > {args.threshold}):")
117
+ for b in range(B):
118
+ n_keep = int(keep_thr[b].sum().item())
119
+ print(f" row {b}: keep {n_keep}/{N} chunks (keep ratio {n_keep / N:.1%})")
120
+
121
+ # ── Selection: top-K ──────────────────────────────────────────────────────
122
+ keep_topk = model.select_topk(hidden, compressed_k, positions,
123
+ top_k=args.top_k, mode=args.ensemble)
124
+ print(f"\n[demo] top-K selection (k={args.top_k}):")
125
+ for b in range(B):
126
+ idx = keep_topk[b].nonzero(as_tuple=True)[0].tolist()
127
+ print(f" row {b}: kept chunk indices = {idx}")
128
+
129
+ print("\n[demo] done. βœ… forward + scoring + selection all ran.")
130
+
131
+
132
+ if __name__ == "__main__":
133
+ main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch>=2.1
2
+ safetensors
3
+ numpy
retriever.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ retriever.py β€” FlashMemory DS-V4 Retriever (standalone reference implementation)
3
+ ===============================================================================
4
+
5
+ A self-contained, dependency-light (torch only) PyTorch reference implementation
6
+ of the **FlashMemory Retriever** used for sparsifying the DeepSeek-V4
7
+ Compressed-Sparse-Attention (CSA) KV cache.
8
+
9
+ Given the hidden state of a decode token, the retriever predicts which CSA
10
+ KV-cache chunks the next tokens will attend to, so that only the top-scoring
11
+ chunks need to stay resident on the GPU.
12
+
13
+ compressed_k [B, N, 132] uint8 β†’ dequant β†’ k [B, N, HEAD_DIM]
14
+ hidden [B, 4096] β†’ q-proj + RoPE + Hadamard β†’ q [B, N_HEADS, HEAD_DIM]
15
+ β†’ weights_proj β†’ fused_w [B, N_HEADS]
16
+
17
+ score = sigmoid( (relu(k @ q^T) · fused_w).sum(heads) ) ∈ [0, 1]
18
+
19
+ The shipped checkpoint is a *joint* checkpoint holding three independent CSA
20
+ layers (l10 / l12 / l20). At inference time the per-layer sigmoid scores are
21
+ ensembled per chunk (cross-layer ``max`` by default, ``mean`` also supported).
22
+
23
+ This file only depends on ``torch``. The full sglang serving integration
24
+ (KV-cache swap, attention-sink, threshold fallback, per-request routing) is
25
+ NOT part of this open release because it depends on the internal DeepSeek-V4
26
+ CSA framework.
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import math
32
+ from collections import OrderedDict
33
+ from typing import Dict, List, Optional, Union
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+
40
+ # ─────────────────────────────────────────────────────────────────────────────
41
+ # RoPE (YaRN) + Hadamard utilities
42
+ # (copied from the project's utils.py so this release is self-contained)
43
+ # ─────────────────────────────────────────────────────────────────────────────
44
+
45
+
46
+ def _yarn_find_correction_dim(n_rot: float, d_model: int, base: float, max_pos: int) -> float:
47
+ return (d_model * math.log(max_pos / (n_rot * 2 * math.pi))) / (2 * math.log(base))
48
+
49
+
50
+ def precompute_freqs_cis(
51
+ dim: int,
52
+ seqlen: int,
53
+ base: float,
54
+ factor: float,
55
+ original_seq_len: int,
56
+ beta_fast: float,
57
+ beta_slow: float,
58
+ ) -> torch.Tensor:
59
+ """YaRN RoPE frequency precomputation.
60
+
61
+ Returns:
62
+ freqs_cis: [seqlen, dim // 2] complex64
63
+ """
64
+ low = max(math.floor(_yarn_find_correction_dim(beta_fast, dim, base, original_seq_len)), 0)
65
+ high = min(math.ceil(_yarn_find_correction_dim(beta_slow, dim, base, original_seq_len)), dim // 2 - 1)
66
+
67
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) # [dim//2]
68
+
69
+ ramp = torch.zeros(dim // 2)
70
+ for i in range(dim // 2):
71
+ if i < low:
72
+ ramp[i] = 0.0
73
+ elif i >= high:
74
+ ramp[i] = 1.0
75
+ else:
76
+ ramp[i] = (i - low) / max(high - low, 1)
77
+
78
+ mixed = freqs * (1 - ramp) + (freqs / factor) * ramp # [dim//2]
79
+ t = torch.arange(seqlen, dtype=torch.float32)
80
+ angles = torch.outer(t, mixed) # [seqlen, dim//2]
81
+ return torch.polar(torch.ones_like(angles), angles) # complex64
82
+
83
+
84
+ def apply_rope(
85
+ q: torch.Tensor,
86
+ freqs_cis: torch.Tensor,
87
+ positions: torch.Tensor,
88
+ rope_dim: int = 64,
89
+ ) -> torch.Tensor:
90
+ """Pure-PyTorch RoPE applied to the last ``rope_dim`` dims of ``q``.
91
+
92
+ Args:
93
+ q: [B, n_heads, head_dim]
94
+ freqs_cis: [max_pos, rope_dim // 2] complex64
95
+ positions: [B] int64
96
+ rope_dim: number of trailing dims to rotate (applied to q[..., -rope_dim:])
97
+
98
+ Returns:
99
+ q after RoPE, same shape as input.
100
+ """
101
+ head_dim = q.shape[-1]
102
+ q_pass = q[..., : head_dim - rope_dim]
103
+ q_rope = q[..., head_dim - rope_dim:]
104
+
105
+ q_c = torch.view_as_complex(
106
+ q_rope.float().reshape(*q_rope.shape[:-1], rope_dim // 2, 2).contiguous()
107
+ ) # [B, H, rope_dim//2]
108
+
109
+ # Clamp positions into the RoPE table range. The freqs_cis table covers
110
+ # max_position entries; tokens beyond it get clamped to the last entry
111
+ # (YaRN extrapolation already makes the tail an approximation, so a few
112
+ # clamped ultra-long positions are far better than an out-of-bounds gather).
113
+ positions = positions.clamp(0, freqs_cis.shape[0] - 1)
114
+
115
+ freqs = freqs_cis[positions].unsqueeze(1) # [B, 1, rope_dim//2]
116
+ q_rot = torch.view_as_real(q_c * freqs).reshape(*q_rope.shape).to(q.dtype)
117
+ return torch.cat([q_pass, q_rot], dim=-1)
118
+
119
+
120
+ def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
121
+ """Normalized Walsh-Hadamard transform over the last dim (must be a power of 2).
122
+
123
+ x: [..., d] β†’ [..., d] (normalized by 1/sqrt(d))
124
+ """
125
+ *leading, d = x.shape
126
+ assert d > 0 and (d & (d - 1)) == 0, f"last dim {d} must be a power of 2"
127
+ h = x.float()
128
+ s = 1
129
+ while s < d:
130
+ h = h.view(*leading, d // (2 * s), 2, s)
131
+ a, b = h[..., 0, :], h[..., 1, :]
132
+ h = torch.stack([a + b, a - b], dim=-2).view(*leading, d)
133
+ s *= 2
134
+ return h / math.sqrt(d)
135
+
136
+
137
+ # ─────────────────────────────────────────────────────────────────────────────
138
+ # compressed-K dequantization
139
+ # ─────────────────────────────────────────────────────────────────────────────
140
+
141
+
142
+ def dequant_compressed_k(compressed_k: torch.Tensor, head_dim: int = 128) -> torch.Tensor:
143
+ """Dequantize compressed CSA keys.
144
+
145
+ Each compressed key is ``head_dim + 4`` bytes:
146
+ bytes[:head_dim] β€” float8_e4m3 quantized key values (1 byte each)
147
+ bytes[head_dim:+4] β€” a single float32 per-chunk scale
148
+
149
+ Args:
150
+ compressed_k: [..., head_dim + 4] uint8
151
+ head_dim: number of key dims (default 128)
152
+
153
+ Returns:
154
+ k: [..., head_dim] float32 ( = fp8_values * scale )
155
+ """
156
+ assert compressed_k.dtype == torch.uint8, (
157
+ f"compressed_k must be uint8, got {compressed_k.dtype}"
158
+ )
159
+ assert compressed_k.shape[-1] == head_dim + 4, (
160
+ f"compressed_k last dim must be {head_dim + 4}, got {compressed_k.shape[-1]}"
161
+ )
162
+
163
+ fp8_bytes = compressed_k[..., :head_dim].contiguous() # uint8 [..., head_dim]
164
+ k_fp8 = fp8_bytes.view(torch.float8_e4m3fn).float() # [..., head_dim]
165
+
166
+ scale_bytes = compressed_k[..., head_dim:head_dim + 4].contiguous() # uint8 [..., 4]
167
+ scale = scale_bytes.view(torch.float32) # [..., 1]
168
+
169
+ return k_fp8 * scale # broadcast β†’ [..., head_dim]
170
+
171
+
172
+ # ─────────────────────────────────────────────────────────────────────────────
173
+ # per-layer scorer module
174
+ # ─────────────────────────────────────────────────────────────────────────────
175
+
176
+
177
+ class _LayerScorer(nn.Module):
178
+ """Holds one CSA layer's retriever weights and computes its logits.
179
+
180
+ Weights are stored as (non-trainable) buffers so ``.to(device)`` / ``.half()``
181
+ move them along with the parent module.
182
+ """
183
+
184
+ def __init__(
185
+ self,
186
+ wq_a: torch.Tensor, # [Q_LORA_RANK, 4096]
187
+ wq_b: torch.Tensor, # [N_HEADS * HEAD_DIM, Q_LORA_RANK]
188
+ q_norm_weight: torch.Tensor, # [Q_LORA_RANK]
189
+ weights_proj: torch.Tensor, # [N_HEADS, 4096]
190
+ n_heads: int,
191
+ head_dim: int,
192
+ rope_dim: int,
193
+ rms_norm_eps: float,
194
+ weight_scale: float,
195
+ ):
196
+ super().__init__()
197
+ self.register_buffer("wq_a", wq_a.to(torch.float32), persistent=False)
198
+ self.register_buffer("wq_b", wq_b.to(torch.float32), persistent=False)
199
+ self.register_buffer("q_norm_weight", q_norm_weight.to(torch.float32), persistent=False)
200
+ self.register_buffer("weights_proj", weights_proj.to(torch.float32), persistent=False)
201
+ self.n_heads = n_heads
202
+ self.head_dim = head_dim
203
+ self.rope_dim = rope_dim
204
+ self.rms_norm_eps = rms_norm_eps
205
+ self.weight_scale = weight_scale
206
+
207
+ def _rmsnorm(self, x: torch.Tensor) -> torch.Tensor:
208
+ x_f = x.float()
209
+ norm = torch.sqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + self.rms_norm_eps)
210
+ return x_f / norm * self.q_norm_weight
211
+
212
+ @torch.no_grad()
213
+ def logits(
214
+ self,
215
+ hidden: torch.Tensor, # [B, 4096]
216
+ k_float: torch.Tensor, # [B, N, head_dim] (already dequantized)
217
+ positions: torch.Tensor, # [B] int64
218
+ freqs_cis: torch.Tensor, # [max_pos, rope_dim//2] complex64
219
+ ) -> torch.Tensor:
220
+ """Return raw (pre-sigmoid) logits [B, N] for this layer."""
221
+ x = hidden.float()
222
+ B = x.shape[0]
223
+
224
+ # ── Q side ──────────────────────────────────────────────────────────
225
+ q_lora = F.linear(x, self.wq_a) # [B, Q_LORA_RANK]
226
+ q_lora = self._rmsnorm(q_lora) # [B, Q_LORA_RANK]
227
+ q = F.linear(q_lora, self.wq_b) # [B, N_HEADS * HEAD_DIM]
228
+ q = q.view(B, self.n_heads, self.head_dim) # [B, N_HEADS, HEAD_DIM]
229
+ # RoPE is applied in bf16 then cast back to float32 to match the trained
230
+ # / deployed scoring path exactly.
231
+ q = apply_rope(q.to(torch.bfloat16), freqs_cis, positions.to(torch.int64),
232
+ rope_dim=self.rope_dim).float()
233
+ q = hadamard_transform(q) # [B, N_HEADS, HEAD_DIM]
234
+
235
+ per_head_w = F.linear(x, self.weights_proj) # [B, N_HEADS]
236
+ fused_w = per_head_w * self.weight_scale # [B, N_HEADS]
237
+
238
+ # ── Score: relu(k @ q^T) weighted-sum over heads ────────────────────
239
+ # q: [B, H, D], k_float: [B, N, D] β†’ [B, N, H]
240
+ scores_per_head = F.relu(torch.einsum("bhd,bnd->bnh", q, k_float)) # [B, N, H]
241
+ logits = (scores_per_head * fused_w.unsqueeze(1)).sum(-1) # [B, N]
242
+ return logits
243
+
244
+
245
+ # ─────────────────────────────────────────────────────────────────────────────
246
+ # FlashMemoryRetriever
247
+ # ─────────────────────────────────────────────────────────────────────────────
248
+
249
+
250
+ class FlashMemoryRetriever(nn.Module):
251
+ """Multi-layer FlashMemory retriever (joint checkpoint).
252
+
253
+ Loads a joint checkpoint whose state-dict keys look like
254
+ ``retrievers.l10.wq_a.weight``, builds one ``_LayerScorer`` per CSA layer,
255
+ and scores compressed-K chunks against a decode token's hidden state.
256
+
257
+ Typical usage::
258
+
259
+ model = FlashMemoryRetriever.from_checkpoint("flashmemory_ds_v4.safetensors",
260
+ device="cuda")
261
+ per_layer = model(hidden_state, compressed_k, positions) # {"l10": [B,N], ...}
262
+ scores = model.ensemble(hidden_state, compressed_k, positions, mode="max") # [B,N]
263
+ """
264
+
265
+ # RoPE / normalization constants (identical across all CSA layers).
266
+ HEAD_DIM = 128
267
+ ROPE_DIM = 64
268
+ ROPE_BASE = 160000.0
269
+ ROPE_FACTOR = 16.0
270
+ ROPE_ORIGINAL_SEQ_LEN = 65536
271
+ ROPE_BETA_FAST = 32.0
272
+ ROPE_BETA_SLOW = 1.0
273
+ RMS_NORM_EPS = 1e-6
274
+
275
+ def __init__(
276
+ self,
277
+ layer_states: "OrderedDict[str, Dict[str, torch.Tensor]]",
278
+ device: Union[str, torch.device] = "cpu",
279
+ max_position: int = 524288,
280
+ head_dim: Optional[int] = None,
281
+ ):
282
+ """
283
+ Args:
284
+ layer_states: ordered mapping ``layer_name -> {"wq_a.weight": ...,
285
+ "wq_b.weight": ..., "q_norm_weight": ..., "weights_proj.weight": ...}``.
286
+ Layer names are arbitrary (e.g. ``"l10"``); ordering is preserved.
287
+ device: device to place the model on.
288
+ max_position: RoPE table length. Must cover the largest token position
289
+ ever scored; positions beyond it are clamped (RoPE becomes an
290
+ approximation). Default 524288; can be raised to 1_048_576 (1M) for
291
+ full-length DeepSeek-V4 contexts.
292
+ head_dim: key/head dimension. Defaults to ``HEAD_DIM`` (128).
293
+ """
294
+ super().__init__()
295
+ assert layer_states, "FlashMemoryRetriever needs at least one layer"
296
+ device = torch.device(device)
297
+ self.head_dim = head_dim if head_dim is not None else self.HEAD_DIM
298
+ self.max_position = max_position
299
+ self.layer_names: List[str] = list(layer_states.keys())
300
+
301
+ # Precompute the (shared) YaRN RoPE table once.
302
+ freqs_cis = precompute_freqs_cis(
303
+ dim=self.ROPE_DIM,
304
+ seqlen=max_position,
305
+ base=self.ROPE_BASE,
306
+ factor=self.ROPE_FACTOR,
307
+ original_seq_len=self.ROPE_ORIGINAL_SEQ_LEN,
308
+ beta_fast=self.ROPE_BETA_FAST,
309
+ beta_slow=self.ROPE_BETA_SLOW,
310
+ )
311
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
312
+
313
+ # Build one scorer per layer.
314
+ self.scorers = nn.ModuleDict()
315
+ for name, st in layer_states.items():
316
+ wq_b = st["wq_b.weight"]
317
+ n_heads = wq_b.shape[0] // self.head_dim
318
+ weight_scale = self.head_dim ** -0.5 * n_heads ** -0.5
319
+ self.scorers[name] = _LayerScorer(
320
+ wq_a=st["wq_a.weight"],
321
+ wq_b=wq_b,
322
+ q_norm_weight=st["q_norm_weight"],
323
+ weights_proj=st["weights_proj.weight"],
324
+ n_heads=n_heads,
325
+ head_dim=self.head_dim,
326
+ rope_dim=self.ROPE_DIM,
327
+ rms_norm_eps=self.RMS_NORM_EPS,
328
+ weight_scale=weight_scale,
329
+ )
330
+
331
+ self.n_heads = next(iter(self.scorers.values())).n_heads
332
+ self.to(device)
333
+
334
+ # ── construction helpers ────────────────────────────────────────────────
335
+
336
+ @staticmethod
337
+ def _split_joint_state(
338
+ state: Dict[str, torch.Tensor],
339
+ layers: Optional[List[str]] = None,
340
+ ) -> "OrderedDict[str, Dict[str, torch.Tensor]]":
341
+ """Split a joint state-dict (keys ``retrievers.l{ID}.*``) into per-layer dicts."""
342
+ is_joint = any(k.startswith("retrievers.") for k in state.keys())
343
+ if not is_joint:
344
+ raise ValueError(
345
+ "State dict is not in joint 'retrievers.l{ID}.*' format. "
346
+ f"Got keys e.g. {list(state.keys())[:3]}"
347
+ )
348
+ found = sorted({k.split(".")[1] for k in state if k.startswith("retrievers.")})
349
+ use_layers = layers if layers is not None else found
350
+ out: "OrderedDict[str, Dict[str, torch.Tensor]]" = OrderedDict()
351
+ wanted = ("wq_a.weight", "wq_b.weight", "q_norm_weight", "weights_proj.weight")
352
+ for lname in use_layers:
353
+ prefix = f"retrievers.{lname}."
354
+ sub = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
355
+ if not sub:
356
+ raise ValueError(
357
+ f"Layer {lname!r} not found in checkpoint. Available: {found}"
358
+ )
359
+ missing = [w for w in wanted if w not in sub]
360
+ if missing:
361
+ raise ValueError(f"Layer {lname!r} missing weights {missing}")
362
+ out[lname] = {w: sub[w] for w in wanted}
363
+ return out
364
+
365
+ @classmethod
366
+ def from_checkpoint(
367
+ cls,
368
+ ckpt_path: str,
369
+ device: Union[str, torch.device] = "cpu",
370
+ max_position: int = 524288,
371
+ layers: Optional[List[str]] = None,
372
+ ) -> "FlashMemoryRetriever":
373
+ """Load a joint checkpoint and build the retriever.
374
+
375
+ Supports both ``.pt`` (``torch.save`` state-dict) and ``.safetensors``
376
+ (HuggingFace convention). Only the learned weights (``wq_a/wq_b/
377
+ q_norm_weight/weights_proj``) are read; the RoPE ``freqs_cis`` table is
378
+ recomputed locally, so a slim ``.safetensors`` loads identically.
379
+
380
+ Args:
381
+ ckpt_path: path to the joint checkpoint (``.pt`` or ``.safetensors``).
382
+ device: device to load onto.
383
+ max_position: RoPE table length (see ``__init__``).
384
+ layers: optional subset of layer names (e.g. ``["l10", "l20"]``). If
385
+ None, all layers found in the checkpoint are used.
386
+ """
387
+ if str(ckpt_path).endswith(".safetensors"):
388
+ from safetensors.torch import load_file
389
+ state = load_file(ckpt_path, device="cpu")
390
+ else:
391
+ state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
392
+ layer_states = cls._split_joint_state(state, layers=layers)
393
+ return cls(layer_states, device=device, max_position=max_position)
394
+
395
+ # ── inference ────────────────────────────────────────────────────────────
396
+
397
+ @torch.no_grad()
398
+ def forward(
399
+ self,
400
+ hidden_state: torch.Tensor, # [B, 4096]
401
+ compressed_k: torch.Tensor, # [B, N, head_dim + 4] uint8
402
+ positions: torch.Tensor, # [B] int64
403
+ apply_sigmoid: bool = True,
404
+ ) -> "OrderedDict[str, torch.Tensor]":
405
+ """Score the compressed-K chunks with every CSA layer.
406
+
407
+ Args:
408
+ hidden_state: [B, 4096] decode-token hidden states.
409
+ compressed_k: [B, N, head_dim + 4] uint8 compressed keys (shared across
410
+ layers in this reference impl β€” see note below).
411
+ positions: [B] int64 token positions (for RoPE).
412
+ apply_sigmoid: if True (default) return sigmoid scores ∈ [0, 1];
413
+ if False return raw logits.
414
+
415
+ Returns:
416
+ OrderedDict ``{layer_name: scores [B, N]}``.
417
+
418
+ Note:
419
+ In the production DeepSeek-V4 CSA system each layer has its *own*
420
+ compressed-K buffer. This reference impl scores all layers against the
421
+ single ``compressed_k`` you pass, which is the right behavior for the
422
+ standalone algorithm demo. If you have per-layer K, call this once per
423
+ layer with that layer's K, or use ``score_layer``.
424
+ """
425
+ device = self.freqs_cis.device
426
+ hidden_state = hidden_state.to(device)
427
+ compressed_k = compressed_k.to(device)
428
+ positions = positions.to(device)
429
+
430
+ k_float = dequant_compressed_k(compressed_k, head_dim=self.head_dim) # [B, N, D]
431
+
432
+ out: "OrderedDict[str, torch.Tensor]" = OrderedDict()
433
+ for name, scorer in self.scorers.items():
434
+ logits = scorer.logits(hidden_state, k_float, positions, self.freqs_cis)
435
+ out[name] = torch.sigmoid(logits) if apply_sigmoid else logits
436
+ return out
437
+
438
+ @torch.no_grad()
439
+ def score_layer(
440
+ self,
441
+ layer_name: str,
442
+ hidden_state: torch.Tensor,
443
+ compressed_k: torch.Tensor,
444
+ positions: torch.Tensor,
445
+ apply_sigmoid: bool = True,
446
+ ) -> torch.Tensor:
447
+ """Score a single layer (useful when each layer has its own K)."""
448
+ device = self.freqs_cis.device
449
+ k_float = dequant_compressed_k(compressed_k.to(device), head_dim=self.head_dim)
450
+ logits = self.scorers[layer_name].logits(
451
+ hidden_state.to(device), k_float, positions.to(device), self.freqs_cis
452
+ )
453
+ return torch.sigmoid(logits) if apply_sigmoid else logits
454
+
455
+ @torch.no_grad()
456
+ def ensemble(
457
+ self,
458
+ hidden_state: torch.Tensor,
459
+ compressed_k: torch.Tensor,
460
+ positions: torch.Tensor,
461
+ mode: str = "max",
462
+ ) -> torch.Tensor:
463
+ """Cross-layer ensemble of per-chunk sigmoid scores.
464
+
465
+ Args:
466
+ mode: ``"max"`` (default) or ``"mean"`` over the per-layer sigmoid
467
+ scores, per chunk.
468
+
469
+ Returns:
470
+ scores [B, N] ∈ [0, 1].
471
+ """
472
+ assert mode in ("max", "mean"), f"unknown ensemble mode: {mode!r}"
473
+ per_layer = self.forward(hidden_state, compressed_k, positions, apply_sigmoid=True)
474
+ stacked = torch.stack(list(per_layer.values()), dim=0) # [L, B, N]
475
+ if mode == "max":
476
+ return stacked.amax(dim=0)
477
+ return stacked.mean(dim=0)
478
+
479
+ @torch.no_grad()
480
+ def select_topk(
481
+ self,
482
+ hidden_state: torch.Tensor,
483
+ compressed_k: torch.Tensor,
484
+ positions: torch.Tensor,
485
+ top_k: Optional[int] = None,
486
+ threshold: Optional[float] = None,
487
+ mode: str = "max",
488
+ ) -> torch.Tensor:
489
+ """Return a boolean keep-mask [B, N] of selected chunks.
490
+
491
+ Exactly one of ``top_k`` / ``threshold`` should be given. With ``top_k``
492
+ the top-k highest-scoring chunks per row are kept; with ``threshold`` all
493
+ chunks whose ensembled sigmoid score exceeds the threshold are kept.
494
+ """
495
+ scores = self.ensemble(hidden_state, compressed_k, positions, mode=mode) # [B, N]
496
+ B, N = scores.shape
497
+ if (top_k is None) == (threshold is None):
498
+ raise ValueError("Provide exactly one of top_k or threshold")
499
+ if threshold is not None:
500
+ return scores > threshold
501
+ k = min(top_k, N)
502
+ keep = torch.zeros(B, N, dtype=torch.bool, device=scores.device)
503
+ idx = scores.topk(k, dim=-1).indices
504
+ keep.scatter_(1, idx, True)
505
+ return keep
toy_flashmemory_inference.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ toy_flashmemory_inference.py β€” Toy sparse-decode loop driven by the FlashMemory Retriever
3
+ =========================================================================================
4
+
5
+ A minimal, torch-only illustration of how the FlashMemory Retriever controls CSA
6
+ memory recall during decode. Every 64 steps the retriever scores all N compressed-K
7
+ chunks against the current decode hidden state, selects the top-K (or thresholded)
8
+ ones to keep, and the rest are masked from attention β€” exactly as if their KV were
9
+ never recalled onto the GPU.
10
+
11
+ This is NOT a real DeepSeek-V4. The "decoder" is a few toy layers with random
12
+ weights. But the retriever, its scoring math, and the decode-time control flow
13
+ are all real.
14
+
15
+ Run::
16
+
17
+ python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import math
24
+ import os
25
+ import sys
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+
31
+ # Ensure sibling retriever.py is importable (works from any cwd).
32
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
33
+ from retriever import FlashMemoryRetriever, dequant_compressed_k # noqa: E402
34
+
35
+
36
+ HIDDEN_DIM = 4096 # fixed: the retriever consumes a [B, 4096] decode hidden state
37
+
38
+
39
+ # ─────────────────────────────────────────────────────────────────────────────
40
+ # Mock CSA KV-cache: N compressed chunks, each [head_dim + 4] uint8
41
+ # (this is the *indexer's* quantized-K representation that the retriever scores)
42
+ # ─────────────────────────────────────────────────────────────────────────────
43
+ def make_mock_compressed_k(
44
+ batch: int,
45
+ n_chunks: int,
46
+ head_dim: int = 128,
47
+ device: str = "cpu",
48
+ seed: int = 0,
49
+ ) -> torch.Tensor:
50
+ """Build a valid mock ``compressed_k`` tensor ``[B, N, head_dim + 4]`` uint8.
51
+
52
+ This mirrors how the real CSA cache stores a compressed key per chunk:
53
+ bytes[:head_dim] β€” float8_e4m3 quantized key values (1 byte each)
54
+ bytes[head_dim:+4] β€” one float32 per-chunk dequant scale
55
+
56
+ In a real FlashMemory run these bytes are produced during *prefill*, when the
57
+ historical KV is compressed and stored. Here we just sample them randomly β€”
58
+ the retriever still runs its exact scoring path over them.
59
+ """
60
+ g = torch.Generator(device=device).manual_seed(seed)
61
+
62
+ # 1) fp8 key bytes
63
+ k_vals = torch.randn(batch, n_chunks, head_dim, generator=g, device=device) * 0.5
64
+ fp8_bytes = k_vals.to(torch.float8_e4m3fn).view(torch.uint8) # [B, N, head_dim]
65
+
66
+ # 2) float32 per-chunk scale β†’ 4 uint8 bytes
67
+ scale = (0.05 + 0.15 * torch.rand(batch, n_chunks, 1, generator=g, device=device)).float()
68
+ scale_bytes = scale.view(torch.uint8) # [B, N, 4]
69
+
70
+ compressed = torch.cat([fp8_bytes, scale_bytes], dim=-1) # [B, N, head_dim + 4]
71
+ assert compressed.shape[-1] == head_dim + 4
72
+ return compressed.contiguous()
73
+
74
+
75
+ # ─────────────────────────────────────────────────────────────────────────────
76
+ # Toy decoder (random weights). Only exists to emit a [B,4096] hidden state
77
+ # each step and own a memory cross-attention over N CSA chunks that the
78
+ # retriever's keep-mask sparsifies.
79
+ # ─────────────────────────────────────────────────────────────────────────────
80
+ def _rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
81
+ norm = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
82
+ return (x.float() * norm).to(x.dtype) * weight
83
+
84
+
85
+ class ToyMemoryDecoder(nn.Module):
86
+ """A few layers of toy memory cross-attention + MLP (random weights)."""
87
+
88
+ def __init__(
89
+ self,
90
+ n_chunks: int,
91
+ n_layers: int = 2,
92
+ n_heads: int = 8,
93
+ vocab_size: int = 512,
94
+ device: str = "cpu",
95
+ seed: int = 0,
96
+ ):
97
+ super().__init__()
98
+ torch.manual_seed(seed)
99
+ self.hidden_dim = HIDDEN_DIM
100
+ self.n_layers = n_layers
101
+ self.n_heads = n_heads
102
+ self.head_dim = self.hidden_dim // n_heads
103
+ self.n_chunks = n_chunks
104
+
105
+ # Token embedding (toy; vocab is meaningless).
106
+ self.embed = nn.Embedding(vocab_size, self.hidden_dim)
107
+
108
+ # Decoder-space memory bank: one vector per CSA chunk (separate from the
109
+ # retriever's compressed_k β€” both index the same N chunks).
110
+ self.register_buffer("memory", torch.randn(n_chunks, self.hidden_dim) * 0.02)
111
+
112
+ # Per-layer projections + norms.
113
+ self.wq = nn.ModuleList(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
114
+ self.wk = nn.ModuleList(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
115
+ self.wv = nn.ModuleList(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
116
+ self.wo = nn.ModuleList(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
117
+ self.mlp_up = nn.ModuleList(nn.Linear(self.hidden_dim, 2 * self.hidden_dim, bias=False) for _ in range(n_layers))
118
+ self.mlp_down = nn.ModuleList(nn.Linear(2 * self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
119
+ self.attn_norm = nn.ParameterList(nn.Parameter(torch.ones(self.hidden_dim)) for _ in range(n_layers))
120
+ self.mlp_norm = nn.ParameterList(nn.Parameter(torch.ones(self.hidden_dim)) for _ in range(n_layers))
121
+ self.final_norm = nn.Parameter(torch.ones(self.hidden_dim))
122
+ self.lm_head = nn.Linear(self.hidden_dim, vocab_size, bias=False)
123
+
124
+ self.to(device)
125
+ self.eval()
126
+
127
+ @torch.no_grad()
128
+ def _memory_attention(self, x: torch.Tensor, layer: int, keep_mask: torch.Tensor | None) -> torch.Tensor:
129
+ """Cross-attention of the current token(s) over the N memory chunks.
130
+
131
+ Args:
132
+ x: [B, hidden] current-token hidden state(s).
133
+ keep_mask: [B, N] bool, True = chunk recalled/kept. ``None`` = keep all
134
+ (the dense path used during prefill / cold-start).
135
+
136
+ Chunks with ``keep_mask == False`` get their attention logit set to
137
+ ``-inf`` β†’ softmax weight 0 β†’ they contribute nothing. THIS is our
138
+ simulation of "the chunk was not recalled onto the GPU".
139
+ """
140
+ B = x.shape[0]
141
+ H, D = self.n_heads, self.head_dim
142
+
143
+ q = self.wq[layer](x).view(B, H, 1, D) # [B, H, 1, D]
144
+ k = self.wk[layer](self.memory).view(self.n_chunks, H, D).permute(1, 0, 2) # [H, N, D]
145
+ v = self.wv[layer](self.memory).view(self.n_chunks, H, D).permute(1, 0, 2) # [H, N, D]
146
+
147
+ # [B, H, 1, N] attention logits over the N memory chunks.
148
+ logits = torch.einsum("bhqd,hnd->bhqn", q, k) / math.sqrt(D)
149
+ if keep_mask is not None:
150
+ # Broadcast [B, N] β†’ [B, 1, 1, N] and mask the dropped chunks.
151
+ drop = ~keep_mask.view(B, 1, 1, self.n_chunks)
152
+ logits = logits.masked_fill(drop, float("-inf"))
153
+
154
+ attn = torch.softmax(logits, dim=-1) # [B, H, 1, N]
155
+ out = torch.einsum("bhqn,hnd->bhqd", attn, v).reshape(B, self.hidden_dim)
156
+ return self.wo[layer](out)
157
+
158
+ @torch.no_grad()
159
+ def step(
160
+ self,
161
+ token_ids: torch.Tensor, # [B] int64
162
+ keep_mask: torch.Tensor | None, # [B, N] bool, or None for dense
163
+ ) -> tuple[torch.Tensor, torch.Tensor]:
164
+ """One decode step. Returns (hidden [B, 4096], next-token logits [B, vocab])."""
165
+ x = self.embed(token_ids) # [B, hidden]
166
+ for layer in range(self.n_layers):
167
+ x = x + self._memory_attention(_rmsnorm(x, self.attn_norm[layer]), layer, keep_mask)
168
+ h = _rmsnorm(x, self.mlp_norm[layer])
169
+ x = x + self.mlp_down[layer](F.gelu(self.mlp_up[layer](h)))
170
+ hidden = _rmsnorm(x, self.final_norm) # [B, 4096] ← feeds retriever
171
+ return hidden, self.lm_head(hidden)
172
+
173
+ @torch.no_grad()
174
+ def prefill(self, prefill_ids: torch.Tensor) -> torch.Tensor:
175
+ """Toy 'prefill': run a short prompt through DENSE memory attention.
176
+
177
+ Returns the last token's hidden state, which seeds the very first
178
+ retrieval cycle (the indexer needs a query hidden state to score against).
179
+ Prefill is intentionally dense (keep_mask=None): the model sees the whole
180
+ history before decoding begins.
181
+ """
182
+ hidden = None
183
+ for t in range(prefill_ids.shape[1]):
184
+ hidden, _ = self.step(prefill_ids[:, t], keep_mask=None)
185
+ return hidden # [B, 4096]
186
+
187
+
188
+ # ─────────────────────────────────────────────────────────────────────────────
189
+ # Retrieval helper: scores β†’ keep-mask (top-K or threshold)
190
+ # ─────────────────────────────────────────────────────────────────────────────
191
+ def scores_to_keep_mask(
192
+ scores: torch.Tensor, # [B, N] sigmoid scores ∈ [0, 1]
193
+ select_mode: str,
194
+ top_k: int,
195
+ threshold: float,
196
+ ) -> torch.Tensor:
197
+ """Turn per-chunk retriever scores into a boolean keep-mask [B, N]."""
198
+ B, N = scores.shape
199
+ if select_mode == "topk":
200
+ k = min(top_k, N)
201
+ keep = torch.zeros(B, N, dtype=torch.bool, device=scores.device)
202
+ idx = scores.topk(k, dim=-1).indices
203
+ keep.scatter_(1, idx, True)
204
+ return keep
205
+ elif select_mode == "threshold":
206
+ return scores > threshold
207
+ raise ValueError(f"unknown select_mode: {select_mode!r}")
208
+
209
+
210
+ # ─────────────────────────────────────────────────────────────────────────────
211
+ # main
212
+ # ─────────────────────────────────────────────────────────────────────────────
213
+ def main():
214
+ ap = argparse.ArgumentParser(
215
+ description="Toy DeepSeek-V4-FlashMemory sparse-decode loop driven by the FlashMemory Retriever"
216
+ )
217
+ ap.add_argument("--ckpt", required=True, help="path to the FlashMemory DS-V4 joint checkpoint (.pt)")
218
+ ap.add_argument("--device", default="cpu", help="cpu or cuda (default: cpu)")
219
+ ap.add_argument("--batch", type=int, default=1, help="number of parallel decode sequences")
220
+ ap.add_argument("--n-chunks", type=int, default=256, help="number of CSA memory chunks (the long history)")
221
+ ap.add_argument("--steps", type=int, default=192, help="number of decode steps to generate")
222
+ ap.add_argument("--retrieval-interval", type=int, default=64,
223
+ help="run the retriever every N decode steps (FlashMemory default 64)")
224
+ ap.add_argument("--select-mode", default="topk", choices=["topk", "threshold"],
225
+ help="how to turn scores into a keep-mask")
226
+ ap.add_argument("--top-k", type=int, default=64, help="chunks to recall per cycle (select-mode=topk)")
227
+ ap.add_argument("--threshold", type=float, default=0.5, help="sigmoid keep threshold (select-mode=threshold)")
228
+ ap.add_argument("--ensemble", default="max", choices=["max", "mean"], help="cross-layer ensemble mode")
229
+ ap.add_argument("--max-position", type=int, default=524288, help="RoPE table length")
230
+ ap.add_argument("--n-layers", type=int, default=2, help="toy decoder layers")
231
+ ap.add_argument("--seed", type=int, default=0)
232
+ args = ap.parse_args()
233
+
234
+ torch.manual_seed(args.seed)
235
+ device = args.device
236
+ B, N = args.batch, args.n_chunks
237
+
238
+ # ── 1. Load retriever ──────────────────────────────────────────────────────
239
+ print(f"FlashMemory DS-V4 β€” toy sparse-decode loop")
240
+ print(f"[load] {args.ckpt}")
241
+ retriever = FlashMemoryRetriever.from_checkpoint(
242
+ args.ckpt, device=device, max_position=args.max_position
243
+ )
244
+ retriever.eval()
245
+ print(f"[load] layers={retriever.layer_names} n_heads={retriever.n_heads} "
246
+ f"head_dim={retriever.head_dim}")
247
+
248
+ # ── 2. Build toy decoder + mock CSA memory ─────────────────────────────────
249
+ decoder = ToyMemoryDecoder(n_chunks=N, n_layers=args.n_layers, device=device, seed=args.seed)
250
+ compressed_k = make_mock_compressed_k(B, N, head_dim=retriever.head_dim,
251
+ device=device, seed=args.seed)
252
+ print(f"[init] decoder: {args.n_layers} layers, {decoder.n_heads} heads | "
253
+ f"CSA memory: {N} chunks [{retriever.head_dim + 4}] uint8")
254
+
255
+ # ── 3. Prefill ─────────────────────────────────────────────────────────────
256
+ prefill_len = 8
257
+ prefill_ids = torch.randint(0, 512, (B, prefill_len), device=device)
258
+ last_hidden = decoder.prefill(prefill_ids)
259
+ base_pos = prefill_len
260
+ last_pos = torch.full((B,), prefill_len - 1, dtype=torch.int64, device=device)
261
+
262
+ sel_desc = (f"top-K={args.top_k}" if args.select_mode == "topk"
263
+ else f"sigmoid>{args.threshold}")
264
+ print(f"\n[decode] {args.steps} steps, retriever every {args.retrieval_interval} steps "
265
+ f"({args.select_mode} [{sel_desc}], ensemble={args.ensemble})")
266
+ print("-" * 60)
267
+
268
+ # ── 4. Decode loop ──────────────────────────────────────────────────────────
269
+ keep_mask = None
270
+ token = decoder.embed.weight.new_zeros(B, dtype=torch.int64)
271
+ keep_ratios: list[float] = []
272
+ cycle = 0
273
+
274
+ for t in range(args.steps):
275
+ abs_pos = base_pos + t
276
+
277
+ if t % args.retrieval_interval == 0:
278
+ scores = retriever.ensemble(last_hidden, compressed_k, last_pos, mode=args.ensemble)
279
+ keep_mask = scores_to_keep_mask(scores, args.select_mode, args.top_k, args.threshold)
280
+
281
+ n_keep = keep_mask.sum(-1)
282
+ ratio = (n_keep.float() / N)
283
+ keep_ratios.extend(ratio.tolist())
284
+ w_lo = abs_pos
285
+ w_hi = min(abs_pos + args.retrieval_interval, base_pos + args.steps) - 1
286
+
287
+ print(f"[cycle {cycle:>2}] pos {w_lo:>5}..{w_hi:<5} | "
288
+ f"keep {fmt_ratio(ratio, B)} ({int(n_keep[0])}/{N}) | "
289
+ f"score mean={scores.mean():.4f} max={scores.max():.4f}")
290
+ cycle += 1
291
+
292
+ hidden, logits = decoder.step(token, keep_mask)
293
+ token = logits.argmax(-1)
294
+ last_hidden = hidden
295
+ last_pos = torch.full((B,), abs_pos, dtype=torch.int64, device=device)
296
+
297
+ # ── 5. Summary ─────────────────────────────────────────────────────────────
298
+ avg_keep = sum(keep_ratios) / max(len(keep_ratios), 1)
299
+ print("-" * 60)
300
+ print(f"[done] {args.steps} tokens, {cycle} cycles, "
301
+ f"avg keep/cycle: {avg_keep:.1%} β†’ ~{1 - avg_keep:.0%} CSA KV dropped")
302
+ print(f"[note] Dropped chunks are masked to -inf in attention (= KV not recalled to GPU). "
303
+ f"Production swap engine not included in this release.")
304
+
305
+
306
+ def fmt_ratio(t: torch.Tensor, B: int) -> str:
307
+ vals = t.tolist()
308
+ return f"{vals[0]:.1%}" if B == 1 else "[" + ", ".join(f"{v:.1%}" for v in vals) + "]"
309
+
310
+
311
+ if __name__ == "__main__":
312
+ main()
weights/flashmemory_ds_v4.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba20d264c309246f824d4471ccc637061b3b0268fe8e4eecc121474a1e5cd02a
3
+ size 509633992