Commit ·
9437df5
1
Parent(s): 9dcfd27
Streamline model card
Browse files
README.md
CHANGED
|
@@ -6,7 +6,6 @@ tags:
|
|
| 6 |
- retrieval
|
| 7 |
- kv-cache
|
| 8 |
- sparse-attention
|
| 9 |
-
- compress-sparse-attention
|
| 10 |
- long-context
|
| 11 |
- flashmemory
|
| 12 |
datasets:
|
|
@@ -18,321 +17,76 @@ datasets:
|
|
| 18 |
|
| 19 |
# FlashMemory DS-V4 Retriever
|
| 20 |
|
| 21 |
-
A
|
| 22 |
-
|
| 23 |
-
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
on long-context tasks while keeping a small fraction of the KV cache on-device.
|
| 30 |
|
| 31 |
-
|
| 32 |
-
demo**. It depends only on `torch` (plus `numpy` / `safetensors` for convenience).
|
| 33 |
-
|
| 34 |
-
> **Scope note.** The full sglang serving integration — KV-cache swap-in/out,
|
| 35 |
-
> attention-sink, threshold fallback, per-request retriever routing — is **not**
|
| 36 |
-
> included here, because it is tightly coupled to the internal DeepSeek-V4 CSA
|
| 37 |
-
> framework and cannot run outside it. This repository provides the retriever
|
| 38 |
-
> **algorithm reference implementation and trained weights only.**
|
| 39 |
-
|
| 40 |
-
---
|
| 41 |
-
|
| 42 |
-
## Model architecture
|
| 43 |
-
|
| 44 |
-
The retriever scores each compressed-K chunk against the decode token's hidden
|
| 45 |
-
state. For a single CSA layer:
|
| 46 |
-
|
| 47 |
-
```
|
| 48 |
-
hidden [B, 4096]
|
| 49 |
-
→ wq_a (4096 → Q_LORA_RANK)
|
| 50 |
-
→ RMSNorm (q_norm_weight, eps=1e-6)
|
| 51 |
-
→ wq_b (Q_LORA_RANK → N_HEADS * HEAD_DIM)
|
| 52 |
-
→ reshape [B, N_HEADS, HEAD_DIM]
|
| 53 |
-
→ RoPE (YaRN, applied to the last ROPE_DIM=64 dims, base=160000)
|
| 54 |
-
→ Hadamard (normalized Walsh-Hadamard transform)
|
| 55 |
-
→ q [B, N_HEADS, HEAD_DIM]
|
| 56 |
-
|
| 57 |
-
hidden [B, 4096]
|
| 58 |
-
→ weights_proj (4096 → N_HEADS)
|
| 59 |
-
→ × weight_scale (= HEAD_DIM^-0.5 * N_HEADS^-0.5)
|
| 60 |
-
→ fused_w [B, N_HEADS]
|
| 61 |
-
|
| 62 |
-
compressed_k [B, N, HEAD_DIM + 4] (uint8)
|
| 63 |
-
→ bytes[:HEAD_DIM] viewed as float8_e4m3 → dequantize
|
| 64 |
-
→ bytes[HEAD_DIM:] viewed as float32 → per-chunk scale
|
| 65 |
-
→ k [B, N, HEAD_DIM]
|
| 66 |
-
|
| 67 |
-
score_per_head = relu( einsum('bnd,bhd->bnh', k, q) ) # [B, N, N_HEADS]
|
| 68 |
-
logit = (score_per_head * fused_w[:, None, :]).sum(-1) # [B, N]
|
| 69 |
-
score = sigmoid(logit) ∈ [0, 1] # [B, N]
|
| 70 |
-
```
|
| 71 |
-
|
| 72 |
-
**Hyperparameters (FlashMemory DS-V4):** `Q_LORA_RANK = 2048`, `N_HEADS = 128`,
|
| 73 |
-
`HEAD_DIM = 128`, `ROPE_DIM = 64`, `ROPE_BASE = 160000`, `ROPE_FACTOR = 16`,
|
| 74 |
-
`ROPE_ORIGINAL_SEQ_LEN = 65536`, `ROPE_BETA_FAST = 32`, `ROPE_BETA_SLOW = 1`,
|
| 75 |
-
`RMS_NORM_EPS = 1e-6`.
|
| 76 |
-
|
| 77 |
-
### Joint multi-layer checkpoint + ensemble
|
| 78 |
-
|
| 79 |
-
FlashMemory DS-V4 is a **joint checkpoint** holding three independent CSA layers
|
| 80 |
-
(`l10`, `l12`, `l20`), each with its own weights. At inference time the per-layer
|
| 81 |
-
sigmoid scores are **ensembled per chunk** — cross-layer `max` (default) or
|
| 82 |
-
`mean` — to produce a single keep/drop decision per chunk.
|
| 83 |
-
|
| 84 |
-
---
|
| 85 |
-
|
| 86 |
-
## What is FlashMemory DS-V4?
|
| 87 |
-
|
| 88 |
-
FlashMemory DS-V4 is part of the latest retraining generation of these retrievers. In the
|
| 89 |
-
project's downstream evaluation it stays close to the full-attention baseline on
|
| 90 |
-
long-context tasks (e.g. RULER, LongMemEval, LongBench V2) while keeping only a
|
| 91 |
-
small fraction of the CSA KV cache on-device (≈90% KV reduction in the deployment
|
| 92 |
-
sweet spot for reasoning-heavy long-context tasks). Precise-needle retrieval
|
| 93 |
-
tasks need an extra threshold-fallback mechanism in the serving layer (not part
|
| 94 |
-
of this standalone release).
|
| 95 |
-
|
| 96 |
-
---
|
| 97 |
-
|
| 98 |
-
## Installation
|
| 99 |
-
|
| 100 |
-
```bash
|
| 101 |
-
pip install -r requirements.txt
|
| 102 |
-
```
|
| 103 |
-
|
| 104 |
-
Only `torch` is strictly required to run the model and demo. `float8_e4m3`
|
| 105 |
-
tensor support requires a reasonably recent PyTorch (≥ 2.1).
|
| 106 |
-
|
| 107 |
-
---
|
| 108 |
-
|
| 109 |
-
## Running the demo
|
| 110 |
|
| 111 |
```bash
|
|
|
|
| 112 |
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
|
| 113 |
```
|
| 114 |
|
| 115 |
-
|
| 116 |
-
set of `uint8` compressed-K chunks, and token positions), loads the FlashMemory DS-V4
|
| 117 |
-
checkpoint, runs the forward pass, prints the per-layer and ensembled per-chunk
|
| 118 |
-
scores, and demonstrates both **threshold** and **top-K** chunk selection.
|
| 119 |
-
|
| 120 |
-
Useful flags:
|
| 121 |
-
|
| 122 |
-
| Flag | Default | Meaning |
|
| 123 |
-
|------|---------|---------|
|
| 124 |
-
| `--device` | `cpu` | `cpu` or `cuda` |
|
| 125 |
-
| `--batch` | `2` | number of decode tokens |
|
| 126 |
-
| `--n-chunks` | `64` | number of compressed-K chunks |
|
| 127 |
-
| `--top-k` | `16` | top-K chunks to select |
|
| 128 |
-
| `--threshold` | `0.5` | sigmoid keep threshold |
|
| 129 |
-
| `--ensemble` | `max` | cross-layer ensemble mode (`max` / `mean`) |
|
| 130 |
-
| `--max-position` | `524288` | RoPE table length (raise to `1048576` for 1M context) |
|
| 131 |
-
|
| 132 |
-
Example output (CPU, default args):
|
| 133 |
-
|
| 134 |
-
```
|
| 135 |
-
[demo] loaded layers=['l10', 'l12', 'l20'] n_heads=128 head_dim=128 max_position=524288
|
| 136 |
-
[demo] per-layer sigmoid score stats (over all chunks):
|
| 137 |
-
l10: min=0.4474 mean=0.5021 max=0.6416
|
| 138 |
-
...
|
| 139 |
-
[demo] threshold selection (sigmoid > 0.5):
|
| 140 |
-
row 0: keep 64/64 chunks (keep ratio 100.0%)
|
| 141 |
-
row 1: keep 49/64 chunks (keep ratio 76.6%)
|
| 142 |
-
[demo] done. ✅ forward + scoring + selection all ran.
|
| 143 |
-
```
|
| 144 |
-
|
| 145 |
-
> The scores above come from **random mock K**, so they cluster near 0.5 — they
|
| 146 |
-
> are only meaningful on real CSA keys. The demo's purpose is to verify the
|
| 147 |
-
> load → forward → selection path end-to-end.
|
| 148 |
-
|
| 149 |
-
---
|
| 150 |
-
|
| 151 |
-
## Using the model in your own code
|
| 152 |
|
| 153 |
```python
|
| 154 |
-
import torch
|
| 155 |
from retriever import FlashMemoryRetriever
|
| 156 |
|
| 157 |
model = FlashMemoryRetriever.from_checkpoint(
|
| 158 |
-
"weights/flashmemory_ds_v4.safetensors", device="cuda"
|
| 159 |
)
|
| 160 |
|
| 161 |
-
hidden
|
| 162 |
-
|
| 163 |
-
positions
|
| 164 |
-
|
| 165 |
-
# Per-layer sigmoid scores: {"l10": [B, N], "l12": [B, N], "l20": [B, N]}
|
| 166 |
-
per_layer = model(hidden, compressed_k, positions)
|
| 167 |
-
|
| 168 |
-
# Cross-layer ensembled per-chunk scores [B, N] ∈ [0, 1]
|
| 169 |
-
scores = model.ensemble(hidden, compressed_k, positions, mode="max")
|
| 170 |
|
| 171 |
-
|
| 172 |
-
keep
|
| 173 |
-
keep = model.select_topk(hidden, compressed_k, positions, threshold=0.5) # threshold
|
| 174 |
```
|
| 175 |
|
| 176 |
-
**`compressed_k` format
|
| 177 |
-
the first `128` bytes are the `float8_e4m3` quantized key values, the last `4`
|
| 178 |
-
bytes are a single `float32` per-chunk scale. Dequantization is
|
| 179 |
-
`fp8_values.view(float8_e4m3).float() * scale`. See `make_mock_compressed_k` in
|
| 180 |
-
`demo.py` for how to construct a valid tensor.
|
| 181 |
-
|
| 182 |
-
---
|
| 183 |
|
| 184 |
-
##
|
| 185 |
|
| 186 |
-
|
|
|
|
| 187 |
|
| 188 |
-
```bash
|
| 189 |
-
huggingface-cli download <HF_REPO> flashmemory_ds_v4.safetensors --local-dir ./weights
|
| 190 |
-
python demo.py --ckpt ./weights/flashmemory_ds_v4.safetensors
|
| 191 |
```
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
-
|
| 194 |
-
`.safetensors` file. The released `.safetensors` is the **slim** form: it stores
|
| 195 |
-
only the four learned tensors per layer
|
| 196 |
-
(`wq_a.weight`, `wq_b.weight`, `q_norm_weight`, `weights_proj.weight` for
|
| 197 |
-
`l10` / `l12` / `l20`) and **omits the `freqs_cis` RoPE table** (≈400 MB), which
|
| 198 |
-
is recomputed at load time from `max_position`. Loading the slim `.safetensors`
|
| 199 |
-
is bit-for-bit identical to loading the full `.pt` (verified by output match).
|
| 200 |
-
|
| 201 |
-
---
|
| 202 |
-
|
| 203 |
-
## Files
|
| 204 |
-
|
| 205 |
-
| File | Purpose |
|
| 206 |
-
|------|---------|
|
| 207 |
-
| `retriever.py` | `FlashMemoryRetriever` model + RoPE/Hadamard utils + FP8 dequant (torch-only, self-contained) |
|
| 208 |
-
| `demo.py` | minimal runnable demo with mock inputs |
|
| 209 |
-
| `toy_flashmemory_inference.py` | toy DeepSeek-V4-FlashMemory sparse-decode loop showing **how the retriever drives memory recall at inference time** (see below) |
|
| 210 |
-
| `requirements.txt` | `torch`, `safetensors`, `numpy` |
|
| 211 |
-
| `LICENSE` | MIT |
|
| 212 |
-
|
| 213 |
-
---
|
| 214 |
-
|
| 215 |
-
## Toy FlashMemory inference reference (`toy_flashmemory_inference.py`)
|
| 216 |
-
|
| 217 |
-
`demo.py` shows a single `hidden → scores` call. `toy_flashmemory_inference.py`
|
| 218 |
-
is the **next step up**: a tiny, fully-runnable illustration of *how the Lightning
|
| 219 |
-
Indexer Retriever is used inside a DeepSeek-V4-FlashMemory style sparse-decode
|
| 220 |
-
loop* to drive "memory recall".
|
| 221 |
-
|
| 222 |
-
It is intentionally small and pedagogical. It depends only on `torch` and the
|
| 223 |
-
sibling `retriever.py`, and it **reuses the real FlashMemory DS-V4 retriever verbatim** — none
|
| 224 |
-
of the scoring math is re-implemented.
|
| 225 |
-
|
| 226 |
-
### The inference flow it demonstrates
|
| 227 |
-
|
| 228 |
-
```
|
| 229 |
-
┌──────────┐ compress & store ┌────────────────────────────┐
|
| 230 |
-
│ PREFILL │ historical K/V │ CSA KV-cache (the memory) │
|
| 231 |
-
│ (dense │ ──────────────────► │ N compressed chunks, │
|
| 232 |
-
│ attn) │ │ each = [132] uint8 fp8-K │
|
| 233 |
-
└────┬─────┘ └──────────────┬─────────────┘
|
| 234 |
-
│ last hidden state │ scored every 64 steps
|
| 235 |
-
▼ │
|
| 236 |
-
┌──────────────────────── DECODE LOOP ─────────┼──────────────────────────┐
|
| 237 |
-
│ for each decode step t: │ │
|
| 238 |
-
│ hidden = toy_decoder.step(token, keep_mask) │ (sparse memory attn) │
|
| 239 |
-
│ │ │
|
| 240 |
-
│ every RETRIEVAL_INTERVAL (= 64) steps: ▼ │
|
| 241 |
-
│ scores[N] = retriever.ensemble(hidden, compressed_k, pos) │
|
| 242 |
-
│ keep_mask[N] = top-K (or sigmoid > threshold) of scores │
|
| 243 |
-
│ → chunks NOT kept are masked to -inf in the next 64 decode steps │
|
| 244 |
-
│ of memory attention (== "not recalled onto the GPU") │
|
| 245 |
-
└──────────────────────────────────────────────────────────────────────────┘
|
| 246 |
```
|
| 247 |
|
| 248 |
-
|
| 249 |
-
last hidden state seeds the first retrieval cycle (the indexer needs a query
|
| 250 |
-
hidden state to score against). In a real run, prefill is also where the
|
| 251 |
-
historical KV is compressed into the `[N, 132]` `uint8` CSA chunks.
|
| 252 |
-
2. **Decode loop.** Every step the toy decoder produces a `[B, 4096]` hidden state
|
| 253 |
-
and attends over the `N` memory chunks.
|
| 254 |
-
3. **Retrieval cycle (every 64 steps).** The real `FlashMemoryRetriever` scores all
|
| 255 |
-
`N` compressed-K chunks against the current decode hidden state, ensembles the
|
| 256 |
-
per-layer (`l10`/`l12`/`l20`) sigmoid scores, and selects the chunks to keep —
|
| 257 |
-
either **top-K** or **sigmoid > threshold**. This predicts which chunks the
|
| 258 |
-
*next ~64 tokens* will attend to.
|
| 259 |
-
4. **Sparse attention.** For the next 64 steps, chunks **not** selected have their
|
| 260 |
-
memory-attention logits set to `-inf`, so they contribute nothing.
|
| 261 |
-
|
| 262 |
-
### What the masking simulates (important)
|
| 263 |
-
|
| 264 |
-
* This toy does **not** perform any real CPU↔GPU KV-cache transfer. The swap-in /
|
| 265 |
-
swap-out machinery is part of the internal FlashMemory engineering and is **not**
|
| 266 |
-
included in this release.
|
| 267 |
-
* We **simulate memory recall by masking the FlashMemory Retriever's per-chunk
|
| 268 |
-
decisions**: a chunk the retriever did not select gets its attention logit set
|
| 269 |
-
to `-inf`. This is equivalent to *"that chunk's KV was never recalled onto the
|
| 270 |
-
GPU, so it cannot be attended to"* — for the attention output, masking a chunk
|
| 271 |
-
out and never loading it produce the same result.
|
| 272 |
-
* The toy's purpose is to make the **decode-time control flow** concrete: where the
|
| 273 |
-
retriever fires, what it consumes (decode hidden state + compressed CSA keys),
|
| 274 |
-
what it produces (a keep/drop mask), and how that mask sparsifies the next
|
| 275 |
-
window of decode steps.
|
| 276 |
-
|
| 277 |
-
### What it is / is NOT
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
randomly-initialized toy attention/MLP whose only jobs are (a) to emit a
|
| 283 |
-
`[B, 4096]` hidden state for the retriever and (b) to own a memory attention we
|
| 284 |
-
can sparsify. The generated tokens are meaningless.
|
| 285 |
-
|
| 286 |
-
> **The production version cannot be released.** It depends on the internal sglang
|
| 287 |
-
> + DeepSeek-V4 CSA framework (native FP8 indexer, real compressed KV-cache,
|
| 288 |
-
> attention-sink, threshold fallback, per-request routing, and the actual KV swap
|
| 289 |
-
> engine). This file shows the *algorithmic role* of the retriever only.
|
| 290 |
-
|
| 291 |
-
### Run
|
| 292 |
|
| 293 |
```bash
|
| 294 |
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
|
| 295 |
```
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
| Flag | Default | Meaning |
|
| 300 |
-
|------|---------|---------|
|
| 301 |
-
| `--n-chunks` | `256` | number of CSA memory chunks (the long history) |
|
| 302 |
-
| `--steps` | `192` | decode steps to generate |
|
| 303 |
-
| `--retrieval-interval` | `64` | run the retriever every N steps (FlashMemory default) |
|
| 304 |
-
| `--select-mode` | `topk` | `topk` or `threshold` |
|
| 305 |
-
| `--top-k` | `64` | chunks to recall per cycle (`select-mode=topk`) |
|
| 306 |
-
| `--threshold` | `0.5` | sigmoid keep threshold (`select-mode=threshold`) |
|
| 307 |
-
| `--ensemble` | `max` | cross-layer ensemble mode (`max` / `mean`) |
|
| 308 |
-
| `--batch` | `1` | parallel decode sequences |
|
| 309 |
-
|
| 310 |
-
Example output (CPU, default args — `top-K=64` out of `256` chunks):
|
| 311 |
-
|
| 312 |
-
```
|
| 313 |
-
FlashMemory DS-V4 — toy sparse-decode loop
|
| 314 |
-
[load] weights/flashmemory_ds_v4.safetensors
|
| 315 |
-
[load] layers=['l10', 'l12', 'l20'] n_heads=128 head_dim=128
|
| 316 |
-
[init] decoder: 2 layers, 8 heads | CSA memory: 256 chunks [132] uint8
|
| 317 |
-
|
| 318 |
-
[decode] 192 steps, retriever every 64 steps (topk [top-K=64], ensemble=max)
|
| 319 |
-
------------------------------------------------------------
|
| 320 |
-
[cycle 0] pos 8..71 | keep 25.0% (64/256) | score mean=0.4910 max=0.5445
|
| 321 |
-
[cycle 1] pos 72..135 | keep 25.0% (64/256) | score mean=0.4910 max=0.5445
|
| 322 |
-
...
|
| 323 |
-
------------------------------------------------------------
|
| 324 |
-
[done] 192 tokens, 3 cycles, avg keep/cycle: 25.0% → ~75% CSA KV dropped
|
| 325 |
-
[note] Dropped chunks are masked to -inf in attention (= KV not recalled to GPU).
|
| 326 |
-
```
|
| 327 |
-
|
| 328 |
-
> As in `demo.py`, the scores come from **random mock K** and cluster near 0.5;
|
| 329 |
-
> they are only meaningful on real CSA keys. The toy's value is the *control flow*
|
| 330 |
-
> — watch each retrieval cycle report how many chunks were scored, recalled, and
|
| 331 |
-
> masked out.
|
| 332 |
|
| 333 |
-
|
| 334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
## License
|
| 337 |
|
| 338 |
-
MIT
|
|
|
|
| 6 |
- retrieval
|
| 7 |
- kv-cache
|
| 8 |
- sparse-attention
|
|
|
|
| 9 |
- long-context
|
| 10 |
- flashmemory
|
| 11 |
datasets:
|
|
|
|
| 17 |
|
| 18 |
# FlashMemory DS-V4 Retriever
|
| 19 |
|
| 20 |
+
A lightweight retriever that sparsifies **DeepSeek-V4 CSA KV-cache**. Given a
|
| 21 |
+
decode-token hidden state, it predicts which compressed-K chunks the next
|
| 22 |
+
~64 tokens will attend to — keeping only those on GPU, offloading the rest.
|
| 23 |
|
| 24 |
+
In downstream evaluation it matches or beats full-attention baseline on
|
| 25 |
+
reasoning-heavy long-context tasks (**RULER, LongMemEval, LongBench V2**)
|
| 26 |
+
while reducing KV-cache usage by **~85–90%**. Precise needle-retrieval tasks
|
| 27 |
+
require an additional threshold-fallback mechanism (not in this release).
|
|
|
|
| 28 |
|
| 29 |
+
## Quick start
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
```bash
|
| 32 |
+
pip install torch safetensors
|
| 33 |
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
|
| 34 |
```
|
| 35 |
|
| 36 |
+
## Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
```python
|
|
|
|
| 39 |
from retriever import FlashMemoryRetriever
|
| 40 |
|
| 41 |
model = FlashMemoryRetriever.from_checkpoint(
|
| 42 |
+
"weights/flashmemory_ds_v4.safetensors", device="cuda"
|
| 43 |
)
|
| 44 |
|
| 45 |
+
# hidden: [B, 4096] decode hidden state
|
| 46 |
+
# compressed_k: [B, N, 132] uint8 CSA keys
|
| 47 |
+
# positions: [B] int64 token positions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
scores = model.ensemble(hidden, compressed_k, positions, mode="max") # [B, N]
|
| 50 |
+
keep = model.select_topk(hidden, compressed_k, positions, top_k=512) # boolean mask
|
|
|
|
| 51 |
```
|
| 52 |
|
| 53 |
+
**`compressed_k` format:** each chunk = 128 bytes `float8_e4m3` values + 4 bytes `float32` scale. See `make_mock_compressed_k()` in `demo.py`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
+
## Architecture
|
| 56 |
|
| 57 |
+
3-layer joint model (`l10`, `l12`, `l20`), 128 heads, 2048 LoRA rank. Per-layer
|
| 58 |
+
sigmoid scores are ensembled (`max` or `mean`) per chunk.
|
| 59 |
|
|
|
|
|
|
|
|
|
|
| 60 |
```
|
| 61 |
+
hidden [B,4096] → q-proj → RoPE(YaRN) → Hadamard → q [B,128,128]
|
| 62 |
+
→ weights_proj → fused_w [B,128]
|
| 63 |
+
compressed_k → FP8 dequant → k [B,N,128]
|
| 64 |
|
| 65 |
+
score = sigmoid( Σ( relu(k @ qᵀ) · fused_w ) ) ∈ [0,1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
```
|
| 67 |
|
| 68 |
+
## Toy inference reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
`toy_flashmemory_inference.py` illustrates how the retriever drives memory
|
| 71 |
+
recall during decode: every 64 steps it re-scores all chunks, and unselected
|
| 72 |
+
ones are masked from attention (equivalent to "not recalled to GPU").
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
```bash
|
| 75 |
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
|
| 76 |
```
|
| 77 |
|
| 78 |
+
> The decoder is a few toy layers with random weights — it is **not** a real
|
| 79 |
+
> DeepSeek-V4. The retriever, scoring math, and decode-time control flow are real.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
## Files
|
| 82 |
|
| 83 |
+
| File | Purpose |
|
| 84 |
+
|------|---------|
|
| 85 |
+
| `retriever.py` | `FlashMemoryRetriever` model (torch-only, self-contained) |
|
| 86 |
+
| `demo.py` | minimal demo with mock inputs |
|
| 87 |
+
| `toy_flashmemory_inference.py` | toy sparse-decode loop |
|
| 88 |
+
| `weights/flashmemory_ds_v4.safetensors` | trained weights (~510 MB) |
|
| 89 |
|
| 90 |
## License
|
| 91 |
|
| 92 |
+
MIT
|