Commit 640b654
Parent(s):
Initial release: FlashMemory DS-V4 Retriever
- FlashMemoryRetriever model (retriever.py)
- Minimal demo with mock inputs (demo.py)
- Toy sparse-decode inference reference (toy_flashmemory_inference.py)
- Model weights (flashmemory_ds_v4.safetensors, ~510 MB)
Co-Authored-By: Claude Code <noreply@anthropic.com>
- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +320 -0
- demo.py +133 -0
- requirements.txt +3 -0
- retriever.py +505 -0
- toy_flashmemory_inference.py +312 -0
- weights/flashmemory_ds_v4.safetensors +3 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
weights/*.safetensors filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 FlashMemory Authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,320 @@
# FlashMemory DS-V4 Retriever

A standalone, dependency-light reference implementation of the **FlashMemory DS-V4
Retriever**, a lightweight retriever that sparsifies the **DeepSeek-V4
Compressed-Sparse-Attention (CSA)** KV cache.

Given the hidden state of a decode token, the retriever predicts which CSA
KV-cache chunks (compressed keys) the upcoming tokens will attend to, so that
only the **top-scoring chunks** need to stay resident on the GPU and the rest can
be offloaded to CPU / disk. This recovers most of the quality of full attention
on long-context tasks while keeping a small fraction of the KV cache on-device.

This release contains the **algorithm + weights + a minimal, runnable PyTorch
demo**. It depends only on `torch` (plus `numpy` / `safetensors` for convenience).

> **Scope note.** The full sglang serving integration (KV-cache swap-in/out,
> attention-sink, threshold fallback, per-request retriever routing) is **not**
> included here, because it is tightly coupled to the internal DeepSeek-V4 CSA
> framework and cannot run outside it. This repository provides the retriever
> **algorithm reference implementation and trained weights only.**

---

## Model architecture

The retriever scores each compressed-K chunk against the decode token's hidden
state. For a single CSA layer:

```
hidden [B, 4096]
  → wq_a        (4096 → Q_LORA_RANK)
  → RMSNorm     (q_norm_weight, eps=1e-6)
  → wq_b        (Q_LORA_RANK → N_HEADS * HEAD_DIM)
  → reshape     [B, N_HEADS, HEAD_DIM]
  → RoPE        (YaRN, applied to the last ROPE_DIM=64 dims, base=160000)
  → Hadamard    (normalized Walsh-Hadamard transform)
  → q           [B, N_HEADS, HEAD_DIM]

hidden [B, 4096]
  → weights_proj   (4096 → N_HEADS)
  → × weight_scale (= HEAD_DIM^-0.5 * N_HEADS^-0.5)
  → fused_w        [B, N_HEADS]

compressed_k [B, N, HEAD_DIM + 4]  (uint8)
  → bytes[:HEAD_DIM] viewed as float8_e4m3 → dequantize
  → bytes[HEAD_DIM:] viewed as float32     → per-chunk scale
  → k [B, N, HEAD_DIM]

score_per_head = relu( einsum('bnd,bhd->bnh', k, q) )            # [B, N, N_HEADS]
logit          = (score_per_head * fused_w[:, None, :]).sum(-1)  # [B, N]
score          = sigmoid(logit) ∈ [0, 1]                         # [B, N]
```

**Hyperparameters (FlashMemory DS-V4):** `Q_LORA_RANK = 2048`, `N_HEADS = 128`,
`HEAD_DIM = 128`, `ROPE_DIM = 64`, `ROPE_BASE = 160000`, `ROPE_FACTOR = 16`,
`ROPE_ORIGINAL_SEQ_LEN = 65536`, `ROPE_BETA_FAST = 32`, `ROPE_BETA_SLOW = 1`,
`RMS_NORM_EPS = 1e-6`.
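
The last three lines of the diagram can be written out for toy sizes in plain Python. This is only a sketch: the helper name `score_chunks`, the two-dimensional vectors, and the example weights are illustrative assumptions, and it takes `q` and `fused_w` as given rather than computing the projections, RoPE, or the Hadamard transform that `retriever.py` applies with batched `torch` ops.

```python
import math


def score_chunks(k, q, fused_w):
    """Score chunks for one decode token (hypothetical helper, not the shipped API).

    k:       [N][D] dequantized chunk keys
    q:       [H][D] per-head query vectors
    fused_w: [H]    per-head weights (already multiplied by weight_scale)
    """
    scores = []
    for chunk in k:
        logit = 0.0
        for head, w in zip(q, fused_w):
            dot = sum(a * b for a, b in zip(chunk, head))
            logit += max(dot, 0.0) * w                  # relu, then head weighting
        scores.append(1.0 / (1.0 + math.exp(-logit)))   # sigmoid
    return scores


# weight_scale from the diagram, with the published hyperparameters.
HEAD_DIM, N_HEADS = 128, 128
weight_scale = HEAD_DIM ** -0.5 * N_HEADS ** -0.5       # equals 1/128

# Toy inputs: 2 chunks, 2 heads, 2 dims (illustrative values only).
k = [[1.0, 0.0], [0.0, -1.0]]
q = [[2.0, 0.0], [0.0, 1.0]]
fused_w = [0.5, 0.5]
scores = score_chunks(k, q, fused_w)
print(weight_scale, scores)
```

In this toy case the second chunk has no positive dot product with any head, so its logit is 0 and its score is exactly 0.5. This matches the remark later in the README that scores on random mock keys cluster near 0.5.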

### Joint multi-layer checkpoint + ensemble

FlashMemory DS-V4 is a **joint checkpoint** holding three independent CSA layers
(`l10`, `l12`, `l20`), each with its own weights. At inference time the per-layer
sigmoid scores are **ensembled per chunk** (cross-layer `max` by default, or
`mean`) to produce a single score, and thus a single keep/drop decision, per chunk.
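
A minimal sketch of the cross-layer ensemble, in plain Python for a single decode token. The function name `ensemble_scores` and the example scores are assumptions made for illustration; the released model exposes this step as `model.ensemble(...)` on `torch` tensors.

```python
def ensemble_scores(per_layer, mode="max"):
    """Combine per-layer chunk scores into one score per chunk.

    per_layer: dict mapping layer name -> list of N sigmoid scores.
    mode:      "max" (any single layer can vote a chunk in) or "mean".
    """
    columns = list(zip(*per_layer.values()))   # one tuple of layer scores per chunk
    if mode == "max":
        return [max(col) for col in columns]
    if mode == "mean":
        return [sum(col) / len(col) for col in columns]
    raise ValueError(f"unknown ensemble mode: {mode!r}")


# Illustrative scores for N = 3 chunks (made-up numbers).
per_layer = {
    "l10": [0.2, 0.9, 0.4],
    "l12": [0.6, 0.1, 0.4],
    "l20": [0.1, 0.2, 0.7],
}
max_scores = ensemble_scores(per_layer, mode="max")
mean_scores = ensemble_scores(per_layer, mode="mean")
print(max_scores, mean_scores)
```

With `max`, a chunk is kept as soon as one layer scores it highly. With `mean`, a single confident layer can be outvoted by the other two, so at the same threshold `mean` never keeps more chunks than `max`.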

---

## What is FlashMemory DS-V4?

FlashMemory DS-V4 is part of the latest retraining generation of these retrievers. In the
project's downstream evaluation it stays close to the full-attention baseline on
long-context tasks (e.g. RULER, LongMemEval, LongBench V2) while keeping only a
small fraction of the CSA KV cache on-device (≈90% KV reduction in the deployment
sweet spot for reasoning-heavy long-context tasks). Precise-needle retrieval
tasks need an extra threshold-fallback mechanism in the serving layer (not part
of this standalone release).

---

## Installation

```bash
pip install -r requirements.txt
```

Only `torch` is strictly required to run the model and demo. `float8_e4m3`
tensor support requires a reasonably recent PyTorch (≥ 2.1).

---

## Running the demo

```bash
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
```

The demo builds **random mock inputs** (a batch of decode-token hidden states, a
set of `uint8` compressed-K chunks, and token positions), loads the FlashMemory DS-V4
checkpoint, runs the forward pass, prints the per-layer and ensembled per-chunk
scores, and demonstrates both **threshold** and **top-K** chunk selection.

Useful flags:

| Flag | Default | Meaning |
|------|---------|---------|
| `--device` | `cpu` | `cpu` or `cuda` |
| `--batch` | `2` | number of decode tokens |
| `--n-chunks` | `64` | number of compressed-K chunks |
| `--top-k` | `16` | top-K chunks to select |
| `--threshold` | `0.5` | sigmoid keep threshold |
| `--ensemble` | `max` | cross-layer ensemble mode (`max` / `mean`) |
| `--max-position` | `524288` | RoPE table length (raise to `1048576` for 1M context) |

Example output (CPU, default args):

```
[demo] loaded layers=['l10', 'l12', 'l20'] n_heads=128 head_dim=128 max_position=524288
[demo] per-layer sigmoid score stats (over all chunks):
  l10: min=0.4474 mean=0.5021 max=0.6416
...
[demo] threshold selection (sigmoid > 0.5):
  row 0: keep 64/64 chunks (keep ratio 100.0%)
  row 1: keep 49/64 chunks (keep ratio 76.6%)
[demo] done. ✅ forward + scoring + selection all ran.
```

> The scores above come from **random mock K**, so they cluster near 0.5; they
> are only meaningful on real CSA keys. The demo's purpose is to verify the
> load → forward → selection path end-to-end.

---

## Using the model in your own code

```python
import torch
from retriever import FlashMemoryRetriever

model = FlashMemoryRetriever.from_checkpoint(
    "weights/flashmemory_ds_v4.safetensors", device="cuda", max_position=524288
)

hidden       = torch.randn(B, 4096, device="cuda")   # decode-token hidden states
compressed_k = ...                                   # [B, N, 132] uint8 CSA keys
positions    = torch.arange(B, device="cuda")        # int64 token positions

# Per-layer sigmoid scores: {"l10": [B, N], "l12": [B, N], "l20": [B, N]}
per_layer = model(hidden, compressed_k, positions)

# Cross-layer ensembled per-chunk scores [B, N] ∈ [0, 1]
scores = model.ensemble(hidden, compressed_k, positions, mode="max")

# Boolean keep-mask [B, N] for the chunks to keep on-device
keep = model.select_topk(hidden, compressed_k, positions, top_k=512)      # top-K
keep = model.select_topk(hidden, compressed_k, positions, threshold=0.5)  # threshold
```

**`compressed_k` format.** Each chunk is `HEAD_DIM + 4 = 132` `uint8` bytes:
the first `128` bytes are the `float8_e4m3` quantized key values, the last `4`
bytes are a single `float32` per-chunk scale. Dequantization is
`fp8_values.view(float8_e4m3).float() * scale`. See `make_mock_compressed_k` in
`demo.py` for how to construct a valid tensor.
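
To make the byte layout concrete, here is a torch-free sketch that decodes one 132-byte chunk by hand. It rests on two assumptions not spelled out above: that the FP8 values use the `e4m3fn` variant (the dtype `demo.py` casts to), and that the `float32` scale is stored little-endian. The helper names `decode_e4m3fn` and `dequant_chunk` are hypothetical; the shipped routine is `dequant_compressed_k` in `retriever.py`.

```python
import struct


def decode_e4m3fn(byte):
    """Decode one float8_e4m3fn byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0x0F
    man = byte & 0x07
    if exp == 0x0F and man == 0x07:
        return float("nan")                         # the only NaN pattern; no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6       # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def dequant_chunk(chunk, head_dim=128):
    """Split a chunk into head_dim fp8 bytes plus a 4-byte float32 scale, then rescale."""
    assert len(chunk) == head_dim + 4
    (scale,) = struct.unpack("<f", chunk[head_dim:])   # assumes little-endian float32
    return [decode_e4m3fn(b) * scale for b in chunk[:head_dim]]


# Build a chunk by hand: 64 bytes of 0x38 (= 1.0), 64 bytes of 0xC0 (= -2.0), scale 0.5.
chunk = bytes([0x38] * 64 + [0xC0] * 64) + struct.pack("<f", 0.5)
values = dequant_chunk(chunk)
print(len(chunk), values[0], values[-1])
```

The largest finite `e4m3fn` magnitude is 448 (byte `0x7E`), which is why a per-chunk `float32` scale is stored next to the quantized values.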

---

## Weights

**Download:** [Hugging Face](https://huggingface.co/<HF_REPO>), file `flashmemory_ds_v4.safetensors` (≈510 MB).

```bash
huggingface-cli download <HF_REPO> flashmemory_ds_v4.safetensors --local-dir ./weights
python demo.py --ckpt ./weights/flashmemory_ds_v4.safetensors
```

`from_checkpoint` accepts either a `.pt` (`torch.save` state-dict) or a
`.safetensors` file. The released `.safetensors` is the **slim** form: it stores
only the four learned tensors per layer
(`wq_a.weight`, `wq_b.weight`, `q_norm_weight`, `weights_proj.weight` for
`l10` / `l12` / `l20`) and **omits the `freqs_cis` RoPE table** (≈400 MB), which
is recomputed at load time from `max_position`. Loading the slim `.safetensors`
gives outputs that are bit-for-bit identical to those from the full `.pt`
(verified by output match).
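
As a rough consistency check on those sizes, the arithmetic below counts the four tensors per layer from the published hyperparameters. Everything about storage here is an assumption rather than something the README states: bias-free projections, `float32` weights, and a `complex64` RoPE table with `ROPE_DIM / 2` entries per position kept once per layer.

```python
HIDDEN, Q_LORA_RANK, N_HEADS, HEAD_DIM = 4096, 2048, 128, 128
ROPE_DIM, MAX_POSITION = 64, 524288
LAYERS = ("l10", "l12", "l20")

params_per_layer = (
    HIDDEN * Q_LORA_RANK                  # wq_a.weight
    + Q_LORA_RANK * N_HEADS * HEAD_DIM    # wq_b.weight
    + Q_LORA_RANK                         # q_norm_weight
    + HIDDEN * N_HEADS                    # weights_proj.weight
)
total_params = params_per_layer * len(LAYERS)

# Assumption: float32 storage (4 bytes per parameter).
weights_mb = total_params * 4 / 1e6

# Assumption: complex64 (8 bytes) table of ROPE_DIM / 2 frequencies per position,
# one copy per layer.
rope_mb = MAX_POSITION * (ROPE_DIM // 2) * 8 * len(LAYERS) / 1e6

print(f"{total_params:,} params, weights ~{weights_mb:.1f} MB, RoPE table ~{rope_mb:.1f} MB")
```

Under these assumptions the weights come to about 509.6 MB and the table to about 402.7 MB, in line with the ≈510 MB and ≈400 MB figures quoted above. The match suggests the assumptions are plausible but does not confirm them.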

---

## Files

| File | Purpose |
|------|---------|
| `retriever.py` | `FlashMemoryRetriever` model + RoPE/Hadamard utils + FP8 dequant (torch-only, self-contained) |
| `demo.py` | minimal runnable demo with mock inputs |
| `toy_flashmemory_inference.py` | toy DeepSeek-V4-FlashMemory sparse-decode loop showing **how the retriever drives memory recall at inference time** (see below) |
| `requirements.txt` | `torch`, `safetensors`, `numpy` |
| `LICENSE` | MIT |

---

## Toy FlashMemory inference reference (`toy_flashmemory_inference.py`)

`demo.py` shows a single `hidden → scores` call. `toy_flashmemory_inference.py`
is the **next step up**: a tiny, fully-runnable illustration of *how the Lightning
Indexer Retriever is used inside a DeepSeek-V4-FlashMemory style sparse-decode
loop* to drive "memory recall".

It is intentionally small and pedagogical. It depends only on `torch` and the
sibling `retriever.py`, and it **reuses the real FlashMemory DS-V4 retriever verbatim**; none
of the scoring math is re-implemented.

### The inference flow it demonstrates

```
┌──────────┐   compress & store    ┌────────────────────────────┐
│ PREFILL  │    historical K/V     │ CSA KV-cache (the memory)  │
│ (dense   │ ────────────────────► │ N compressed chunks,       │
│  attn)   │                       │ each = [132] uint8 fp8-K   │
└────┬─────┘                       └──────────────┬─────────────┘
     │ last hidden state                          │ scored every 64 steps
     ▼                                            │
┌──────────────────────── DECODE LOOP ────────────┼──────────────────────────┐
│ for each decode step t:                         │                          │
│   hidden = toy_decoder.step(token, keep_mask)   │ (sparse memory attn)     │
│                                                 │                          │
│   every RETRIEVAL_INTERVAL (= 64) steps:        ▼                          │
│     scores[N]    = retriever.ensemble(hidden, compressed_k, pos)           │
│     keep_mask[N] = top-K (or sigmoid > threshold) of scores                │
│     → chunks NOT kept are masked to -inf in the next 64 decode steps       │
│       of memory attention (== "not recalled onto the GPU")                 │
└────────────────────────────────────────────────────────────────────────────┘
```

1. **Prefill (dense).** A short prompt is run through dense memory attention. Its
   last hidden state seeds the first retrieval cycle (the indexer needs a query
   hidden state to score against). In a real run, prefill is also where the
   historical KV is compressed into the `[N, 132]` `uint8` CSA chunks.
2. **Decode loop.** Every step the toy decoder produces a `[B, 4096]` hidden state
   and attends over the `N` memory chunks.
3. **Retrieval cycle (every 64 steps).** The real `FlashMemoryRetriever` scores all
   `N` compressed-K chunks against the current decode hidden state, ensembles the
   per-layer (`l10`/`l12`/`l20`) sigmoid scores, and selects the chunks to keep,
   either **top-K** or **sigmoid > threshold**. This predicts which chunks the
   *next ~64 tokens* will attend to.
4. **Sparse attention.** For the next 64 steps, chunks **not** selected have their
   memory-attention logits set to `-inf`, so they contribute nothing.
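
The schedule in steps 2 to 4 can be sketched without any model at all. The helpers below (`retrieval_windows`, `topk_mask`) are hypothetical names, not functions from `toy_flashmemory_inference.py`; the starting position 8 is read off the example log further down and assumed to be the prompt length, and the scores are made-up numbers.

```python
def retrieval_windows(start_pos, steps, interval):
    """Return the (first, last) token positions covered by each retrieval cycle."""
    windows = []
    for first in range(0, steps, interval):
        last = min(first + interval, steps) - 1
        windows.append((start_pos + first, start_pos + last))
    return windows


def topk_mask(scores, k):
    """Boolean keep-mask selecting the k highest-scoring chunks."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    mask = [False] * len(scores)
    for i in ranked[:k]:
        mask[i] = True
    return mask


# Default toy settings: 192 decode steps, retriever fires every 64 steps.
windows = retrieval_windows(start_pos=8, steps=192, interval=64)

# One cycle's selection over 4 chunks (illustrative scores), keeping the top 2.
keep_mask = topk_mask([0.2, 0.9, 0.4, 0.7], k=2)
print(windows, keep_mask)
```

Each window reuses the mask computed at its first step, so the retriever runs `ceil(steps / interval)` times rather than once per token.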

### What the masking simulates (important)

* This toy does **not** perform any real CPU↔GPU KV-cache transfer. The swap-in /
  swap-out machinery is part of the internal FlashMemory engineering and is **not**
  included in this release.
* We **simulate memory recall by applying the FlashMemory Retriever's per-chunk
  decisions as an attention mask**: a chunk the retriever did not select gets its
  attention logit set to `-inf`. This is equivalent to *"that chunk's KV was never
  recalled onto the GPU, so it cannot be attended to"*: for the attention output,
  masking a chunk out and never loading it produce the same result.
* The toy's purpose is to make the **decode-time control flow** concrete: where the
  retriever fires, what it consumes (decode hidden state + compressed CSA keys),
  what it produces (a keep/drop mask), and how that mask sparsifies the next
  window of decode steps.
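
The equivalence claimed in the second bullet is easy to check numerically. The sketch below uses scalar values and hand-picked logits (all illustrative assumptions) and compares softmax attention with dropped chunks masked to `-inf` against attention computed over the kept chunks only.

```python
import math


def softmax(logits):
    """Numerically stable softmax; assumes at least one finite logit."""
    peak = max(x for x in logits if x != -math.inf)
    exps = [math.exp(x - peak) for x in logits]     # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def attend_masked(logits, values, keep):
    """Attention over all chunks, with dropped chunks masked to -inf."""
    masked = [x if kept else -math.inf for x, kept in zip(logits, keep)]
    return sum(p * v for p, v in zip(softmax(masked), values))


def attend_loaded_only(logits, values, keep):
    """Attention over only the chunks that were 'recalled' (never sees the rest)."""
    kept_logits = [x for x, kept in zip(logits, keep) if kept]
    kept_values = [v for v, kept in zip(values, keep) if kept]
    return sum(p * v for p, v in zip(softmax(kept_logits), kept_values))


logits = [1.0, 2.0, 0.5, 3.0]
values = [10.0, 20.0, 30.0, 40.0]
keep = [True, False, False, True]

masked_out = attend_masked(logits, values, keep)
never_loaded = attend_loaded_only(logits, values, keep)
print(masked_out, never_loaded)
```

Masked chunks receive probability exactly 0, so the two outputs agree. That is why the toy can stand in for a real swap engine as far as attention results are concerned, though it says nothing about transfer cost.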

### What it is / is NOT

* **IS:** a minimal, torch-only illustration of the decode-time control flow that
  drives memory recall with the real FlashMemory DS-V4 retriever.
* **IS NOT:** a runnable DeepSeek-V4. The "decoder" is a couple of layers of
  randomly-initialized toy attention/MLP whose only jobs are (a) to emit a
  `[B, 4096]` hidden state for the retriever and (b) to own a memory attention we
  can sparsify. The generated tokens are meaningless.

> **The production version cannot be released.** It depends on the internal sglang
> + DeepSeek-V4 CSA framework (native FP8 indexer, real compressed KV-cache,
> attention-sink, threshold fallback, per-request routing, and the actual KV swap
> engine). This file shows the *algorithmic role* of the retriever only.

### Run

```bash
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
```

Runs on CPU by default; pass `--device cuda` for GPU.

| Flag | Default | Meaning |
|------|---------|---------|
| `--n-chunks` | `256` | number of CSA memory chunks (the long history) |
| `--steps` | `192` | decode steps to generate |
| `--retrieval-interval` | `64` | run the retriever every N steps (FlashMemory default) |
| `--select-mode` | `topk` | `topk` or `threshold` |
| `--top-k` | `64` | chunks to recall per cycle (`select-mode=topk`) |
| `--threshold` | `0.5` | sigmoid keep threshold (`select-mode=threshold`) |
| `--ensemble` | `max` | cross-layer ensemble mode (`max` / `mean`) |
| `--batch` | `1` | parallel decode sequences |

Example output (CPU, default args: `top-K=64` out of `256` chunks):

```
FlashMemory DS-V4 - toy sparse-decode loop
[load] weights/flashmemory_ds_v4.safetensors
[load] layers=['l10', 'l12', 'l20'] n_heads=128 head_dim=128
[init] decoder: 2 layers, 8 heads | CSA memory: 256 chunks [132] uint8

[decode] 192 steps, retriever every 64 steps (topk [top-K=64], ensemble=max)
------------------------------------------------------------
[cycle 0] pos 8..71 | keep 25.0% (64/256) | score mean=0.4910 max=0.5445
[cycle 1] pos 72..135 | keep 25.0% (64/256) | score mean=0.4910 max=0.5445
...
------------------------------------------------------------
[done] 192 tokens, 3 cycles, avg keep/cycle: 25.0% → ~75% CSA KV dropped
[note] Dropped chunks are masked to -inf in attention (= KV not recalled to GPU).
```

> As in `demo.py`, the scores come from **random mock K** and cluster near 0.5;
> they are only meaningful on real CSA keys. The toy's value is the *control flow*:
> watch each retrieval cycle report how many chunks were scored, recalled, and
> masked out.

---

## License

MIT; see [`LICENSE`](./LICENSE).
demo.py
ADDED
@@ -0,0 +1,133 @@
"""
demo.py - minimal standalone demo for the FlashMemory DS-V4 Retriever
=====================================================================

Builds random mock inputs, loads the FlashMemory DS-V4 joint checkpoint, runs
a forward pass, and prints per-chunk scores plus a top-K selection summary.

Run::

    python demo.py --ckpt weights/flashmemory_ds_v4.safetensors

Runs on CPU by default; pass ``--device cuda`` to use a GPU.
"""

from __future__ import annotations

import argparse

import torch

from retriever import FlashMemoryRetriever, dequant_compressed_k


def make_mock_compressed_k(
    batch: int,
    n_chunks: int,
    head_dim: int = 128,
    device: str = "cpu",
    seed: int = 0,
) -> torch.Tensor:
    """Construct a valid mock ``compressed_k`` tensor [B, N, head_dim + 4] uint8.

    Layout per chunk: ``head_dim`` float8_e4m3 bytes followed by one float32 scale
    (4 bytes). We build it the same way the real CSA cache stores it:
      1. sample random key vectors, cast to float8_e4m3, view as uint8;
      2. sample a small positive per-chunk scale, view its float32 as 4 uint8 bytes;
      3. concatenate along the last dim.
    """
    g = torch.Generator(device=device).manual_seed(seed)

    # 1) fp8 key bytes
    k_vals = torch.randn(batch, n_chunks, head_dim, generator=g, device=device) * 0.5
    k_fp8 = k_vals.to(torch.float8_e4m3fn)
    fp8_bytes = k_fp8.view(torch.uint8)  # [B, N, head_dim]

    # 2) float32 per-chunk scale → 4 bytes
    scale = (0.05 + 0.15 * torch.rand(batch, n_chunks, 1, generator=g, device=device)).float()
    scale_bytes = scale.view(torch.uint8)  # [B, N, 4]

    compressed = torch.cat([fp8_bytes, scale_bytes], dim=-1)  # [B, N, head_dim + 4]
    assert compressed.shape[-1] == head_dim + 4
    return compressed.contiguous()


def main():
    ap = argparse.ArgumentParser(description="FlashMemory DS-V4 Retriever demo")
    ap.add_argument("--ckpt", required=True,
                    help="path to joint checkpoint (.pt or .safetensors)")
    ap.add_argument("--device", default="cpu", help="cpu or cuda (default: cpu)")
    ap.add_argument("--batch", type=int, default=2, help="number of decode tokens")
    ap.add_argument("--n-chunks", type=int, default=64, help="number of compressed-K chunks")
    ap.add_argument("--max-position", type=int, default=524288,
                    help="RoPE table length (raise to 1048576 for 1M context)")
    ap.add_argument("--top-k", type=int, default=16, help="top-K chunks to select")
    ap.add_argument("--threshold", type=float, default=0.5, help="sigmoid keep threshold")
    ap.add_argument("--ensemble", default="max", choices=["max", "mean"],
                    help="cross-layer ensemble mode")
    ap.add_argument("--seed", type=int, default=0)
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    device = args.device

    print(f"[demo] loading checkpoint: {args.ckpt}")
    model = FlashMemoryRetriever.from_checkpoint(
        args.ckpt, device=device, max_position=args.max_position
    )
    model.eval()
    print(f"[demo] loaded layers={model.layer_names} n_heads={model.n_heads} "
          f"head_dim={model.head_dim} max_position={model.max_position}")

    # ── Mock inputs ─────────────────────────────────────────────────────────
    B, N = args.batch, args.n_chunks
    hidden = torch.randn(B, 4096, device=device, dtype=torch.float32)
    compressed_k = make_mock_compressed_k(B, N, head_dim=model.head_dim,
                                          device=device, seed=args.seed)
    # token positions for each decode token (arbitrary; here spaced out)
    positions = torch.arange(B, device=device, dtype=torch.int64) * 1000 + 4096

    print(f"\n[demo] mock inputs: hidden={tuple(hidden.shape)} "
          f"compressed_k={tuple(compressed_k.shape)} ({compressed_k.dtype}) "
          f"positions={positions.tolist()}")

    # sanity: show dequant works
    k_float = dequant_compressed_k(compressed_k, head_dim=model.head_dim)
    print(f"[demo] dequantized K: shape={tuple(k_float.shape)} "
          f"mean={k_float.mean().item():+.4f} std={k_float.std().item():.4f}")

    # ── Per-layer scores ────────────────────────────────────────────────────
    per_layer = model(hidden, compressed_k, positions, apply_sigmoid=True)
    print("\n[demo] per-layer sigmoid score stats (over all chunks):")
    for name, s in per_layer.items():
        print(f" {name}: min={s.min().item():.4f} mean={s.mean().item():.4f} "
              f"max={s.max().item():.4f}")

    # ── Cross-layer ensemble ────────────────────────────────────────────────
    scores = model.ensemble(hidden, compressed_k, positions, mode=args.ensemble)  # [B, N]
    print(f"\n[demo] ensembled ({args.ensemble}) per-chunk scores [B={B}, N={N}]:")
    for b in range(B):
        row = scores[b]
        preview = ", ".join(f"{v:.3f}" for v in row[:12].tolist())
        print(f" row {b}: [{preview}{', ...' if N > 12 else ''}]")

    # ── Selection: threshold ────────────────────────────────────────────────
    keep_thr = model.select_topk(hidden, compressed_k, positions,
                                 threshold=args.threshold, mode=args.ensemble)
    print(f"\n[demo] threshold selection (sigmoid > {args.threshold}):")
    for b in range(B):
        n_keep = int(keep_thr[b].sum().item())
        print(f" row {b}: keep {n_keep}/{N} chunks (keep ratio {n_keep / N:.1%})")

    # ── Selection: top-K ────────────────────────────────────────────────────
    keep_topk = model.select_topk(hidden, compressed_k, positions,
                                  top_k=args.top_k, mode=args.ensemble)
    print(f"\n[demo] top-K selection (k={args.top_k}):")
    for b in range(B):
        idx = keep_topk[b].nonzero(as_tuple=True)[0].tolist()
        print(f" row {b}: kept chunk indices = {idx}")

    print("\n[demo] done. ✅ forward + scoring + selection all ran.")


if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,3 @@
torch>=2.1
safetensors
numpy
retriever.py
ADDED
@@ -0,0 +1,505 @@
| 1 |
+
"""
|
| 2 |
+
retriever.py - FlashMemory DS-V4 Retriever (standalone reference implementation)
|
| 3 |
+
===============================================================================
|
| 4 |
+
|
| 5 |
+
A self-contained, dependency-light (torch only) PyTorch reference implementation
|
| 6 |
+
of the **FlashMemory Retriever** used for sparsifying the DeepSeek-V4
|
| 7 |
+
Compressed-Sparse-Attention (CSA) KV cache.
|
| 8 |
+
|
| 9 |
+
Given the hidden state of a decode token, the retriever predicts which CSA
|
| 10 |
+
KV-cache chunks the next tokens will attend to, so that only the top-scoring
|
| 11 |
+
chunks need to stay resident on the GPU.
|
| 12 |
+
|
| 13 |
+
    compressed_k [B, N, 132] uint8 → dequant → k [B, N, HEAD_DIM]
|
| 14 |
+
    hidden [B, 4096] → q-proj + RoPE + Hadamard → q [B, N_HEADS, HEAD_DIM]
|
| 15 |
+
                     → weights_proj → fused_w [B, N_HEADS]
|
| 16 |
+
|
| 17 |
+
    score = sigmoid( (relu(k @ q^T) · fused_w).sum(heads) ) ∈ [0, 1]
|
| 18 |
+
|
| 19 |
+
The shipped checkpoint is a *joint* checkpoint holding three independent CSA
|
| 20 |
+
layers (l10 / l12 / l20). At inference time the per-layer sigmoid scores are
|
| 21 |
+
ensembled per chunk (cross-layer ``max`` by default, ``mean`` also supported).
|
| 22 |
+
|
| 23 |
+
This file only depends on ``torch``. The full sglang serving integration
|
| 24 |
+
(KV-cache swap, attention-sink, threshold fallback, per-request routing) is
|
| 25 |
+
NOT part of this open release because it depends on the internal DeepSeek-V4
|
| 26 |
+
CSA framework.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import math
|
| 32 |
+
from collections import OrderedDict
|
| 33 |
+
from typing import Dict, List, Optional, Union
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn as nn
|
| 37 |
+
import torch.nn.functional as F
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 41 |
+
# RoPE (YaRN) + Hadamard utilities
|
| 42 |
+
# (copied from the project's utils.py so this release is self-contained)
|
| 43 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _yarn_find_correction_dim(n_rot: float, d_model: int, base: float, max_pos: int) -> float:
|
| 47 |
+
return (d_model * math.log(max_pos / (n_rot * 2 * math.pi))) / (2 * math.log(base))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def precompute_freqs_cis(
|
| 51 |
+
dim: int,
|
| 52 |
+
seqlen: int,
|
| 53 |
+
base: float,
|
| 54 |
+
factor: float,
|
| 55 |
+
original_seq_len: int,
|
| 56 |
+
beta_fast: float,
|
| 57 |
+
beta_slow: float,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
"""YaRN RoPE frequency precomputation.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
freqs_cis: [seqlen, dim // 2] complex64
|
| 63 |
+
"""
|
| 64 |
+
low = max(math.floor(_yarn_find_correction_dim(beta_fast, dim, base, original_seq_len)), 0)
|
| 65 |
+
high = min(math.ceil(_yarn_find_correction_dim(beta_slow, dim, base, original_seq_len)), dim // 2 - 1)
|
| 66 |
+
|
| 67 |
+
freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) # [dim//2]
|
| 68 |
+
|
| 69 |
+
ramp = torch.zeros(dim // 2)
|
| 70 |
+
for i in range(dim // 2):
|
| 71 |
+
if i < low:
|
| 72 |
+
ramp[i] = 0.0
|
| 73 |
+
elif i >= high:
|
| 74 |
+
ramp[i] = 1.0
|
| 75 |
+
else:
|
| 76 |
+
ramp[i] = (i - low) / max(high - low, 1)
|
| 77 |
+
|
| 78 |
+
mixed = freqs * (1 - ramp) + (freqs / factor) * ramp # [dim//2]
|
| 79 |
+
t = torch.arange(seqlen, dtype=torch.float32)
|
| 80 |
+
angles = torch.outer(t, mixed) # [seqlen, dim//2]
|
| 81 |
+
return torch.polar(torch.ones_like(angles), angles) # complex64
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def apply_rope(
|
| 85 |
+
q: torch.Tensor,
|
| 86 |
+
freqs_cis: torch.Tensor,
|
| 87 |
+
positions: torch.Tensor,
|
| 88 |
+
rope_dim: int = 64,
|
| 89 |
+
) -> torch.Tensor:
|
| 90 |
+
"""Pure-PyTorch RoPE applied to the last ``rope_dim`` dims of ``q``.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
q: [B, n_heads, head_dim]
|
| 94 |
+
freqs_cis: [max_pos, rope_dim // 2] complex64
|
| 95 |
+
positions: [B] int64
|
| 96 |
+
rope_dim: number of trailing dims to rotate (applied to q[..., -rope_dim:])
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
q after RoPE, same shape as input.
|
| 100 |
+
"""
|
| 101 |
+
head_dim = q.shape[-1]
|
| 102 |
+
q_pass = q[..., : head_dim - rope_dim]
|
| 103 |
+
q_rope = q[..., head_dim - rope_dim:]
|
| 104 |
+
|
| 105 |
+
q_c = torch.view_as_complex(
|
| 106 |
+
q_rope.float().reshape(*q_rope.shape[:-1], rope_dim // 2, 2).contiguous()
|
| 107 |
+
) # [B, H, rope_dim//2]
|
| 108 |
+
|
| 109 |
+
# Clamp positions into the RoPE table range. The freqs_cis table covers
|
| 110 |
+
# max_position entries; tokens beyond it get clamped to the last entry
|
| 111 |
+
# (YaRN extrapolation already makes the tail an approximation, so a few
|
| 112 |
+
# clamped ultra-long positions are far better than an out-of-bounds gather).
|
| 113 |
+
positions = positions.clamp(0, freqs_cis.shape[0] - 1)
|
| 114 |
+
|
| 115 |
+
freqs = freqs_cis[positions].unsqueeze(1) # [B, 1, rope_dim//2]
|
| 116 |
+
q_rot = torch.view_as_real(q_c * freqs).reshape(*q_rope.shape).to(q.dtype)
|
| 117 |
+
return torch.cat([q_pass, q_rot], dim=-1)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
"""Normalized Walsh-Hadamard transform over the last dim (must be a power of 2).
|
| 122 |
+
|
| 123 |
+
    x: [..., d] → [..., d]  (normalized by 1/sqrt(d))
|
| 124 |
+
"""
|
| 125 |
+
*leading, d = x.shape
|
| 126 |
+
assert d > 0 and (d & (d - 1)) == 0, f"last dim {d} must be a power of 2"
|
| 127 |
+
h = x.float()
|
| 128 |
+
s = 1
|
| 129 |
+
while s < d:
|
| 130 |
+
h = h.view(*leading, d // (2 * s), 2, s)
|
| 131 |
+
a, b = h[..., 0, :], h[..., 1, :]
|
| 132 |
+
h = torch.stack([a + b, a - b], dim=-2).view(*leading, d)
|
| 133 |
+
s *= 2
|
| 134 |
+
return h / math.sqrt(d)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 138 |
+
# compressed-K dequantization
|
| 139 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def dequant_compressed_k(compressed_k: torch.Tensor, head_dim: int = 128) -> torch.Tensor:
|
| 143 |
+
"""Dequantize compressed CSA keys.
|
| 144 |
+
|
| 145 |
+
Each compressed key is ``head_dim + 4`` bytes:
|
| 146 |
+
        bytes[:head_dim]              - float8_e4m3fn quantized key values (1 byte each)
|
| 147 |
+
        bytes[head_dim:head_dim + 4]  - a single float32 per-chunk scale
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
compressed_k: [..., head_dim + 4] uint8
|
| 151 |
+
head_dim: number of key dims (default 128)
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
k: [..., head_dim] float32 ( = fp8_values * scale )
|
| 155 |
+
"""
|
| 156 |
+
assert compressed_k.dtype == torch.uint8, (
|
| 157 |
+
f"compressed_k must be uint8, got {compressed_k.dtype}"
|
| 158 |
+
)
|
| 159 |
+
assert compressed_k.shape[-1] == head_dim + 4, (
|
| 160 |
+
f"compressed_k last dim must be {head_dim + 4}, got {compressed_k.shape[-1]}"
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
fp8_bytes = compressed_k[..., :head_dim].contiguous() # uint8 [..., head_dim]
|
| 164 |
+
k_fp8 = fp8_bytes.view(torch.float8_e4m3fn).float() # [..., head_dim]
|
| 165 |
+
|
| 166 |
+
scale_bytes = compressed_k[..., head_dim:head_dim + 4].contiguous() # uint8 [..., 4]
|
| 167 |
+
scale = scale_bytes.view(torch.float32) # [..., 1]
|
| 168 |
+
|
| 169 |
+
    return k_fp8 * scale  # broadcast → [..., head_dim]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 173 |
+
# per-layer scorer module
|
| 174 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class _LayerScorer(nn.Module):
|
| 178 |
+
"""Holds one CSA layer's retriever weights and computes its logits.
|
| 179 |
+
|
| 180 |
+
Weights are stored as (non-trainable) buffers so ``.to(device)`` / ``.half()``
|
| 181 |
+
move them along with the parent module.
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
def __init__(
|
| 185 |
+
self,
|
| 186 |
+
wq_a: torch.Tensor, # [Q_LORA_RANK, 4096]
|
| 187 |
+
wq_b: torch.Tensor, # [N_HEADS * HEAD_DIM, Q_LORA_RANK]
|
| 188 |
+
q_norm_weight: torch.Tensor, # [Q_LORA_RANK]
|
| 189 |
+
weights_proj: torch.Tensor, # [N_HEADS, 4096]
|
| 190 |
+
n_heads: int,
|
| 191 |
+
head_dim: int,
|
| 192 |
+
rope_dim: int,
|
| 193 |
+
rms_norm_eps: float,
|
| 194 |
+
weight_scale: float,
|
| 195 |
+
):
|
| 196 |
+
super().__init__()
|
| 197 |
+
self.register_buffer("wq_a", wq_a.to(torch.float32), persistent=False)
|
| 198 |
+
self.register_buffer("wq_b", wq_b.to(torch.float32), persistent=False)
|
| 199 |
+
self.register_buffer("q_norm_weight", q_norm_weight.to(torch.float32), persistent=False)
|
| 200 |
+
self.register_buffer("weights_proj", weights_proj.to(torch.float32), persistent=False)
|
| 201 |
+
self.n_heads = n_heads
|
| 202 |
+
self.head_dim = head_dim
|
| 203 |
+
self.rope_dim = rope_dim
|
| 204 |
+
self.rms_norm_eps = rms_norm_eps
|
| 205 |
+
self.weight_scale = weight_scale
|
| 206 |
+
|
| 207 |
+
def _rmsnorm(self, x: torch.Tensor) -> torch.Tensor:
|
| 208 |
+
x_f = x.float()
|
| 209 |
+
norm = torch.sqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + self.rms_norm_eps)
|
| 210 |
+
return x_f / norm * self.q_norm_weight
|
| 211 |
+
|
| 212 |
+
@torch.no_grad()
|
| 213 |
+
def logits(
|
| 214 |
+
self,
|
| 215 |
+
hidden: torch.Tensor, # [B, 4096]
|
| 216 |
+
k_float: torch.Tensor, # [B, N, head_dim] (already dequantized)
|
| 217 |
+
positions: torch.Tensor, # [B] int64
|
| 218 |
+
freqs_cis: torch.Tensor, # [max_pos, rope_dim//2] complex64
|
| 219 |
+
) -> torch.Tensor:
|
| 220 |
+
"""Return raw (pre-sigmoid) logits [B, N] for this layer."""
|
| 221 |
+
x = hidden.float()
|
| 222 |
+
B = x.shape[0]
|
| 223 |
+
|
| 224 |
+
        # ── Q side ──────────────────────────────────────────────────────────
|
| 225 |
+
q_lora = F.linear(x, self.wq_a) # [B, Q_LORA_RANK]
|
| 226 |
+
q_lora = self._rmsnorm(q_lora) # [B, Q_LORA_RANK]
|
| 227 |
+
q = F.linear(q_lora, self.wq_b) # [B, N_HEADS * HEAD_DIM]
|
| 228 |
+
q = q.view(B, self.n_heads, self.head_dim) # [B, N_HEADS, HEAD_DIM]
|
| 229 |
+
# RoPE is applied in bf16 then cast back to float32 to match the trained
|
| 230 |
+
# / deployed scoring path exactly.
|
| 231 |
+
q = apply_rope(q.to(torch.bfloat16), freqs_cis, positions.to(torch.int64),
|
| 232 |
+
rope_dim=self.rope_dim).float()
|
| 233 |
+
q = hadamard_transform(q) # [B, N_HEADS, HEAD_DIM]
|
| 234 |
+
|
| 235 |
+
per_head_w = F.linear(x, self.weights_proj) # [B, N_HEADS]
|
| 236 |
+
fused_w = per_head_w * self.weight_scale # [B, N_HEADS]
|
| 237 |
+
|
| 238 |
+
        # ── Score: relu(k @ q^T) weighted-sum over heads ────────────────────
|
| 239 |
+
        # q: [B, H, D], k_float: [B, N, D] → [B, N, H]
|
| 240 |
+
scores_per_head = F.relu(torch.einsum("bhd,bnd->bnh", q, k_float)) # [B, N, H]
|
| 241 |
+
logits = (scores_per_head * fused_w.unsqueeze(1)).sum(-1) # [B, N]
|
| 242 |
+
return logits
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 246 |
+
# FlashMemoryRetriever
|
| 247 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class FlashMemoryRetriever(nn.Module):
|
| 251 |
+
"""Multi-layer FlashMemory retriever (joint checkpoint).
|
| 252 |
+
|
| 253 |
+
Loads a joint checkpoint whose state-dict keys look like
|
| 254 |
+
``retrievers.l10.wq_a.weight``, builds one ``_LayerScorer`` per CSA layer,
|
| 255 |
+
and scores compressed-K chunks against a decode token's hidden state.
|
| 256 |
+
|
| 257 |
+
Typical usage::
|
| 258 |
+
|
| 259 |
+
model = FlashMemoryRetriever.from_checkpoint("flashmemory_ds_v4.safetensors",
|
| 260 |
+
device="cuda")
|
| 261 |
+
per_layer = model(hidden_state, compressed_k, positions) # {"l10": [B,N], ...}
|
| 262 |
+
scores = model.ensemble(hidden_state, compressed_k, positions, mode="max") # [B,N]
|
| 263 |
+
"""
|
| 264 |
+
|
| 265 |
+
# RoPE / normalization constants (identical across all CSA layers).
|
| 266 |
+
HEAD_DIM = 128
|
| 267 |
+
ROPE_DIM = 64
|
| 268 |
+
ROPE_BASE = 160000.0
|
| 269 |
+
ROPE_FACTOR = 16.0
|
| 270 |
+
ROPE_ORIGINAL_SEQ_LEN = 65536
|
| 271 |
+
ROPE_BETA_FAST = 32.0
|
| 272 |
+
ROPE_BETA_SLOW = 1.0
|
| 273 |
+
RMS_NORM_EPS = 1e-6
|
| 274 |
+
|
| 275 |
+
def __init__(
|
| 276 |
+
self,
|
| 277 |
+
layer_states: "OrderedDict[str, Dict[str, torch.Tensor]]",
|
| 278 |
+
device: Union[str, torch.device] = "cpu",
|
| 279 |
+
max_position: int = 524288,
|
| 280 |
+
head_dim: Optional[int] = None,
|
| 281 |
+
):
|
| 282 |
+
"""
|
| 283 |
+
Args:
|
| 284 |
+
layer_states: ordered mapping ``layer_name -> {"wq_a.weight": ...,
|
| 285 |
+
"wq_b.weight": ..., "q_norm_weight": ..., "weights_proj.weight": ...}``.
|
| 286 |
+
Layer names are arbitrary (e.g. ``"l10"``); ordering is preserved.
|
| 287 |
+
device: device to place the model on.
|
| 288 |
+
max_position: RoPE table length. Must cover the largest token position
|
| 289 |
+
ever scored; positions beyond it are clamped (RoPE becomes an
|
| 290 |
+
approximation). Default 524288; can be raised to 1_048_576 (1M) for
|
| 291 |
+
full-length DeepSeek-V4 contexts.
|
| 292 |
+
head_dim: key/head dimension. Defaults to ``HEAD_DIM`` (128).
|
| 293 |
+
"""
|
| 294 |
+
super().__init__()
|
| 295 |
+
assert layer_states, "FlashMemoryRetriever needs at least one layer"
|
| 296 |
+
device = torch.device(device)
|
| 297 |
+
self.head_dim = head_dim if head_dim is not None else self.HEAD_DIM
|
| 298 |
+
self.max_position = max_position
|
| 299 |
+
self.layer_names: List[str] = list(layer_states.keys())
|
| 300 |
+
|
| 301 |
+
# Precompute the (shared) YaRN RoPE table once.
|
| 302 |
+
freqs_cis = precompute_freqs_cis(
|
| 303 |
+
dim=self.ROPE_DIM,
|
| 304 |
+
seqlen=max_position,
|
| 305 |
+
base=self.ROPE_BASE,
|
| 306 |
+
factor=self.ROPE_FACTOR,
|
| 307 |
+
original_seq_len=self.ROPE_ORIGINAL_SEQ_LEN,
|
| 308 |
+
beta_fast=self.ROPE_BETA_FAST,
|
| 309 |
+
beta_slow=self.ROPE_BETA_SLOW,
|
| 310 |
+
)
|
| 311 |
+
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
| 312 |
+
|
| 313 |
+
# Build one scorer per layer.
|
| 314 |
+
self.scorers = nn.ModuleDict()
|
| 315 |
+
for name, st in layer_states.items():
|
| 316 |
+
wq_b = st["wq_b.weight"]
|
| 317 |
+
n_heads = wq_b.shape[0] // self.head_dim
|
| 318 |
+
weight_scale = self.head_dim ** -0.5 * n_heads ** -0.5
|
| 319 |
+
self.scorers[name] = _LayerScorer(
|
| 320 |
+
wq_a=st["wq_a.weight"],
|
| 321 |
+
wq_b=wq_b,
|
| 322 |
+
q_norm_weight=st["q_norm_weight"],
|
| 323 |
+
weights_proj=st["weights_proj.weight"],
|
| 324 |
+
n_heads=n_heads,
|
| 325 |
+
head_dim=self.head_dim,
|
| 326 |
+
rope_dim=self.ROPE_DIM,
|
| 327 |
+
rms_norm_eps=self.RMS_NORM_EPS,
|
| 328 |
+
weight_scale=weight_scale,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
self.n_heads = next(iter(self.scorers.values())).n_heads
|
| 332 |
+
self.to(device)
|
| 333 |
+
|
| 334 |
+
    # ── construction helpers ────────────────────────────────────────────────
|
| 335 |
+
|
| 336 |
+
@staticmethod
|
| 337 |
+
def _split_joint_state(
|
| 338 |
+
state: Dict[str, torch.Tensor],
|
| 339 |
+
layers: Optional[List[str]] = None,
|
| 340 |
+
) -> "OrderedDict[str, Dict[str, torch.Tensor]]":
|
| 341 |
+
"""Split a joint state-dict (keys ``retrievers.l{ID}.*``) into per-layer dicts."""
|
| 342 |
+
is_joint = any(k.startswith("retrievers.") for k in state.keys())
|
| 343 |
+
if not is_joint:
|
| 344 |
+
raise ValueError(
|
| 345 |
+
"State dict is not in joint 'retrievers.l{ID}.*' format. "
|
| 346 |
+
f"Got keys e.g. {list(state.keys())[:3]}"
|
| 347 |
+
)
|
| 348 |
+
found = sorted({k.split(".")[1] for k in state if k.startswith("retrievers.")})
|
| 349 |
+
use_layers = layers if layers is not None else found
|
| 350 |
+
out: "OrderedDict[str, Dict[str, torch.Tensor]]" = OrderedDict()
|
| 351 |
+
wanted = ("wq_a.weight", "wq_b.weight", "q_norm_weight", "weights_proj.weight")
|
| 352 |
+
for lname in use_layers:
|
| 353 |
+
prefix = f"retrievers.{lname}."
|
| 354 |
+
sub = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
|
| 355 |
+
if not sub:
|
| 356 |
+
raise ValueError(
|
| 357 |
+
f"Layer {lname!r} not found in checkpoint. Available: {found}"
|
| 358 |
+
)
|
| 359 |
+
missing = [w for w in wanted if w not in sub]
|
| 360 |
+
if missing:
|
| 361 |
+
raise ValueError(f"Layer {lname!r} missing weights {missing}")
|
| 362 |
+
out[lname] = {w: sub[w] for w in wanted}
|
| 363 |
+
return out
|
| 364 |
+
|
| 365 |
+
@classmethod
|
| 366 |
+
def from_checkpoint(
|
| 367 |
+
cls,
|
| 368 |
+
ckpt_path: str,
|
| 369 |
+
device: Union[str, torch.device] = "cpu",
|
| 370 |
+
max_position: int = 524288,
|
| 371 |
+
layers: Optional[List[str]] = None,
|
| 372 |
+
) -> "FlashMemoryRetriever":
|
| 373 |
+
"""Load a joint checkpoint and build the retriever.
|
| 374 |
+
|
| 375 |
+
Supports both ``.pt`` (``torch.save`` state-dict) and ``.safetensors``
|
| 376 |
+
(HuggingFace convention). Only the learned weights (``wq_a/wq_b/
|
| 377 |
+
q_norm_weight/weights_proj``) are read; the RoPE ``freqs_cis`` table is
|
| 378 |
+
recomputed locally, so a slim ``.safetensors`` loads identically.
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
ckpt_path: path to the joint checkpoint (``.pt`` or ``.safetensors``).
|
| 382 |
+
device: device to load onto.
|
| 383 |
+
max_position: RoPE table length (see ``__init__``).
|
| 384 |
+
layers: optional subset of layer names (e.g. ``["l10", "l20"]``). If
|
| 385 |
+
None, all layers found in the checkpoint are used.
|
| 386 |
+
"""
|
| 387 |
+
if str(ckpt_path).endswith(".safetensors"):
|
| 388 |
+
from safetensors.torch import load_file
|
| 389 |
+
state = load_file(ckpt_path, device="cpu")
|
| 390 |
+
else:
|
| 391 |
+
state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
| 392 |
+
layer_states = cls._split_joint_state(state, layers=layers)
|
| 393 |
+
return cls(layer_states, device=device, max_position=max_position)
|
| 394 |
+
|
| 395 |
+
    # ── inference ────────────────────────────────────────────────────────────
|
| 396 |
+
|
| 397 |
+
@torch.no_grad()
|
| 398 |
+
def forward(
|
| 399 |
+
self,
|
| 400 |
+
hidden_state: torch.Tensor, # [B, 4096]
|
| 401 |
+
compressed_k: torch.Tensor, # [B, N, head_dim + 4] uint8
|
| 402 |
+
positions: torch.Tensor, # [B] int64
|
| 403 |
+
apply_sigmoid: bool = True,
|
| 404 |
+
) -> "OrderedDict[str, torch.Tensor]":
|
| 405 |
+
"""Score the compressed-K chunks with every CSA layer.
|
| 406 |
+
|
| 407 |
+
Args:
|
| 408 |
+
hidden_state: [B, 4096] decode-token hidden states.
|
| 409 |
+
compressed_k: [B, N, head_dim + 4] uint8 compressed keys (shared across
|
| 410 |
+
                layers in this reference impl; see note below).
|
| 411 |
+
positions: [B] int64 token positions (for RoPE).
|
| 412 |
+
            apply_sigmoid: if True (default) return sigmoid scores ∈ [0, 1];
|
| 413 |
+
if False return raw logits.
|
| 414 |
+
|
| 415 |
+
Returns:
|
| 416 |
+
OrderedDict ``{layer_name: scores [B, N]}``.
|
| 417 |
+
|
| 418 |
+
Note:
|
| 419 |
+
In the production DeepSeek-V4 CSA system each layer has its *own*
|
| 420 |
+
compressed-K buffer. This reference impl scores all layers against the
|
| 421 |
+
single ``compressed_k`` you pass, which is the right behavior for the
|
| 422 |
+
standalone algorithm demo. If you have per-layer K, call this once per
|
| 423 |
+
layer with that layer's K, or use ``score_layer``.
|
| 424 |
+
"""
|
| 425 |
+
device = self.freqs_cis.device
|
| 426 |
+
hidden_state = hidden_state.to(device)
|
| 427 |
+
compressed_k = compressed_k.to(device)
|
| 428 |
+
positions = positions.to(device)
|
| 429 |
+
|
| 430 |
+
k_float = dequant_compressed_k(compressed_k, head_dim=self.head_dim) # [B, N, D]
|
| 431 |
+
|
| 432 |
+
out: "OrderedDict[str, torch.Tensor]" = OrderedDict()
|
| 433 |
+
for name, scorer in self.scorers.items():
|
| 434 |
+
logits = scorer.logits(hidden_state, k_float, positions, self.freqs_cis)
|
| 435 |
+
out[name] = torch.sigmoid(logits) if apply_sigmoid else logits
|
| 436 |
+
return out
|
| 437 |
+
|
| 438 |
+
@torch.no_grad()
|
| 439 |
+
def score_layer(
|
| 440 |
+
self,
|
| 441 |
+
layer_name: str,
|
| 442 |
+
hidden_state: torch.Tensor,
|
| 443 |
+
compressed_k: torch.Tensor,
|
| 444 |
+
positions: torch.Tensor,
|
| 445 |
+
apply_sigmoid: bool = True,
|
| 446 |
+
) -> torch.Tensor:
|
| 447 |
+
"""Score a single layer (useful when each layer has its own K)."""
|
| 448 |
+
device = self.freqs_cis.device
|
| 449 |
+
k_float = dequant_compressed_k(compressed_k.to(device), head_dim=self.head_dim)
|
| 450 |
+
logits = self.scorers[layer_name].logits(
|
| 451 |
+
hidden_state.to(device), k_float, positions.to(device), self.freqs_cis
|
| 452 |
+
)
|
| 453 |
+
return torch.sigmoid(logits) if apply_sigmoid else logits
|
| 454 |
+
|
| 455 |
+
@torch.no_grad()
|
| 456 |
+
def ensemble(
|
| 457 |
+
self,
|
| 458 |
+
hidden_state: torch.Tensor,
|
| 459 |
+
compressed_k: torch.Tensor,
|
| 460 |
+
positions: torch.Tensor,
|
| 461 |
+
mode: str = "max",
|
| 462 |
+
) -> torch.Tensor:
|
| 463 |
+
"""Cross-layer ensemble of per-chunk sigmoid scores.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
mode: ``"max"`` (default) or ``"mean"`` over the per-layer sigmoid
|
| 467 |
+
scores, per chunk.
|
| 468 |
+
|
| 469 |
+
Returns:
|
| 470 |
+
            scores [B, N] ∈ [0, 1].
|
| 471 |
+
"""
|
| 472 |
+
assert mode in ("max", "mean"), f"unknown ensemble mode: {mode!r}"
|
| 473 |
+
per_layer = self.forward(hidden_state, compressed_k, positions, apply_sigmoid=True)
|
| 474 |
+
stacked = torch.stack(list(per_layer.values()), dim=0) # [L, B, N]
|
| 475 |
+
if mode == "max":
|
| 476 |
+
return stacked.amax(dim=0)
|
| 477 |
+
return stacked.mean(dim=0)
|
| 478 |
+
|
| 479 |
+
@torch.no_grad()
|
| 480 |
+
def select_topk(
|
| 481 |
+
self,
|
| 482 |
+
hidden_state: torch.Tensor,
|
| 483 |
+
compressed_k: torch.Tensor,
|
| 484 |
+
positions: torch.Tensor,
|
| 485 |
+
top_k: Optional[int] = None,
|
| 486 |
+
threshold: Optional[float] = None,
|
| 487 |
+
mode: str = "max",
|
| 488 |
+
) -> torch.Tensor:
|
| 489 |
+
"""Return a boolean keep-mask [B, N] of selected chunks.
|
| 490 |
+
|
| 491 |
+
Exactly one of ``top_k`` / ``threshold`` should be given. With ``top_k``
|
| 492 |
+
the top-k highest-scoring chunks per row are kept; with ``threshold`` all
|
| 493 |
+
chunks whose ensembled sigmoid score exceeds the threshold are kept.
|
| 494 |
+
"""
|
| 495 |
+
scores = self.ensemble(hidden_state, compressed_k, positions, mode=mode) # [B, N]
|
| 496 |
+
B, N = scores.shape
|
| 497 |
+
if (top_k is None) == (threshold is None):
|
| 498 |
+
raise ValueError("Provide exactly one of top_k or threshold")
|
| 499 |
+
if threshold is not None:
|
| 500 |
+
return scores > threshold
|
| 501 |
+
k = min(top_k, N)
|
| 502 |
+
keep = torch.zeros(B, N, dtype=torch.bool, device=scores.device)
|
| 503 |
+
idx = scores.topk(k, dim=-1).indices
|
| 504 |
+
keep.scatter_(1, idx, True)
|
| 505 |
+
return keep
|
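The scoring path in retriever.py can be checked by hand on a tiny case. The sketch below is not part of the release. It re-implements, in plain Python without torch, the formula from the module docstring (per-head ReLU dot products, a weighted sum over heads, then a sigmoid), followed by the cross-layer ensemble and a top-k keep-mask. The helper names (`chunk_logit`, `score_chunks`, `ensemble`, `topk_mask`) and all numbers are hypothetical and chosen only for illustration. RoPE, the Hadamard transform, and dequantization are deliberately left out.

```python
import math


def chunk_logit(q_heads, k, head_weights):
    """sum_h relu(k . q_h) * w_h for one chunk (illustrative helper)."""
    total = 0.0
    for q, w in zip(q_heads, head_weights):
        dot = sum(qi * ki for qi, ki in zip(q, k))
        total += max(dot, 0.0) * w  # ReLU, then per-head weight
    return total


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def score_chunks(q_heads, keys, head_weights):
    """Sigmoid score in [0, 1] for every chunk key."""
    return [sigmoid(chunk_logit(q_heads, k, head_weights)) for k in keys]


def ensemble(per_layer_scores, mode="max"):
    """Combine per-layer score lists chunk by chunk."""
    columns = list(zip(*per_layer_scores))
    if mode == "max":
        return [max(c) for c in columns]
    if mode == "mean":
        return [sum(c) / len(c) for c in columns]
    raise ValueError(f"unknown ensemble mode: {mode!r}")


def topk_mask(scores, k):
    """Boolean keep-mask marking the k highest-scoring chunks."""
    k = min(k, len(scores))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = [False] * len(scores)
    for i in order[:k]:
        keep[i] = True
    return keep


# Toy setup (made-up values): 2 heads of dim 2, 3 chunks, 2 "layers" that
# share the same queries but use different per-head weights.
q_heads = [[1.0, 0.0], [0.0, 1.0]]
keys = [[2.0, 0.0], [-1.0, -1.0], [0.0, 2.0]]

layer_a = score_chunks(q_heads, keys, head_weights=[1.0, 0.5])  # logits 2, 0, 1
layer_b = score_chunks(q_heads, keys, head_weights=[0.0, 2.0])  # logits 0, 0, 4

combined = ensemble([layer_a, layer_b], mode="max")
keep = topk_mask(combined, k=2)
print(keep)
```

Chunk 1 has a negative dot product with every head, so the ReLU zeroes its logit in both layers and its score stays at sigmoid(0) = 0.5; it is the chunk dropped by the top-2 mask. With `mode="max"`, a chunk is kept if any single layer rates it highly, which is the behavior the docstring names as the default.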
toy_flashmemory_inference.py
ADDED
|
@@ -0,0 +1,312 @@
| 1 |
+
"""
|
| 2 |
+
toy_flashmemory_inference.py - Toy sparse-decode loop driven by the FlashMemory Retriever
|
| 3 |
+
=========================================================================================
|
| 4 |
+
|
| 5 |
+
A minimal, torch-only illustration of how the FlashMemory Retriever controls CSA
|
| 6 |
+
memory recall during decode. Every 64 steps the retriever scores all N compressed-K
|
| 7 |
+
chunks against the current decode hidden state, selects the top-K (or thresholded)
|
| 8 |
+
ones to keep, and the rest are masked from attention, exactly as if their KV were
|
| 9 |
+
never recalled onto the GPU.
|
| 10 |
+
|
| 11 |
+
This is NOT a real DeepSeek-V4. The "decoder" is a few toy layers with random
|
| 12 |
+
weights. But the retriever, its scoring math, and the decode-time control flow
|
| 13 |
+
are all real.
|
| 14 |
+
|
| 15 |
+
Run::
|
| 16 |
+
|
| 17 |
+
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import math
|
| 24 |
+
import os
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
|
| 31 |
+
# Ensure sibling retriever.py is importable (works from any cwd).
|
| 32 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 33 |
+
from retriever import FlashMemoryRetriever, dequant_compressed_k # noqa: E402
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
HIDDEN_DIM = 4096 # fixed: the retriever consumes a [B, 4096] decode hidden state
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 40 |
+
# Mock CSA KV-cache: N compressed chunks, each [head_dim + 4] uint8
|
| 41 |
+
# (this is the *indexer's* quantized-K representation that the retriever scores)
|
| 42 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 43 |
+
def make_mock_compressed_k(
|
| 44 |
+
batch: int,
|
| 45 |
+
n_chunks: int,
|
| 46 |
+
head_dim: int = 128,
|
| 47 |
+
device: str = "cpu",
|
| 48 |
+
seed: int = 0,
|
| 49 |
+
) -> torch.Tensor:
|
| 50 |
+
"""Build a valid mock ``compressed_k`` tensor ``[B, N, head_dim + 4]`` uint8.
|
| 51 |
+
|
| 52 |
+
This mirrors how the real CSA cache stores a compressed key per chunk:
|
| 53 |
+
        bytes[:head_dim]              - float8_e4m3fn quantized key values (1 byte each)
|
| 54 |
+
        bytes[head_dim:head_dim + 4]  - one float32 per-chunk dequant scale
|
| 55 |
+
|
| 56 |
+
In a real FlashMemory run these bytes are produced during *prefill*, when the
|
| 57 |
+
    historical KV is compressed and stored. Here we just sample them randomly;
|
| 58 |
+
the retriever still runs its exact scoring path over them.
|
| 59 |
+
"""
|
| 60 |
+
g = torch.Generator(device=device).manual_seed(seed)
|
| 61 |
+
|
| 62 |
+
# 1) fp8 key bytes
|
| 63 |
+
k_vals = torch.randn(batch, n_chunks, head_dim, generator=g, device=device) * 0.5
|
| 64 |
+
fp8_bytes = k_vals.to(torch.float8_e4m3fn).view(torch.uint8) # [B, N, head_dim]
|
| 65 |
+
|
| 66 |
+
    # 2) float32 per-chunk scale → 4 uint8 bytes
|
| 67 |
+
scale = (0.05 + 0.15 * torch.rand(batch, n_chunks, 1, generator=g, device=device)).float()
|
| 68 |
+
scale_bytes = scale.view(torch.uint8) # [B, N, 4]
|
| 69 |
+
|
| 70 |
+
compressed = torch.cat([fp8_bytes, scale_bytes], dim=-1) # [B, N, head_dim + 4]
|
| 71 |
+
assert compressed.shape[-1] == head_dim + 4
|
| 72 |
+
return compressed.contiguous()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 76 |
+
# Toy decoder (random weights). Only exists to emit a [B,4096] hidden state
|
| 77 |
+
# each step and own a memory cross-attention over N CSA chunks that the
|
| 78 |
+
# retriever's keep-mask sparsifies.
|
| 79 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 80 |
+
def _rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
|
| 81 |
+
norm = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
|
| 82 |
+
return (x.float() * norm).to(x.dtype) * weight
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class ToyMemoryDecoder(nn.Module):
|
| 86 |
+
"""A few layers of toy memory cross-attention + MLP (random weights)."""
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
n_chunks: int,
|
| 91 |
+
n_layers: int = 2,
|
| 92 |
+
n_heads: int = 8,
|
| 93 |
+
vocab_size: int = 512,
|
| 94 |
+
device: str = "cpu",
|
| 95 |
+
seed: int = 0,
|
| 96 |
+
):
|
| 97 |
+
super().__init__()
|
| 98 |
+
torch.manual_seed(seed)
|
| 99 |
+
self.hidden_dim = HIDDEN_DIM
|
| 100 |
+
self.n_layers = n_layers
|
| 101 |
+
self.n_heads = n_heads
|
| 102 |
+
self.head_dim = self.hidden_dim // n_heads
|
| 103 |
+
self.n_chunks = n_chunks
|
| 104 |
+
|
| 105 |
+
# Token embedding (toy; vocab is meaningless).
|
| 106 |
+
self.embed = nn.Embedding(vocab_size, self.hidden_dim)
|
| 107 |
+
|
| 108 |
+
# Decoder-space memory bank: one vector per CSA chunk (separate from the
|
| 109 |
+
        # retriever's compressed_k; both index the same N chunks).
|
| 110 |
+
self.register_buffer("memory", torch.randn(n_chunks, self.hidden_dim) * 0.02)
|
| 111 |
+
|
| 112 |
+
# Per-layer projections + norms.
|
| 113 |
+
self.wq = nn.ModuleList(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
|
| 114 |
+
self.wk = nn.ModuleList(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
|
| 115 |
+
self.wv = nn.ModuleList(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
|
| 116 |
+
self.wo = nn.ModuleList(nn.Linear(self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
|
| 117 |
+
self.mlp_up = nn.ModuleList(nn.Linear(self.hidden_dim, 2 * self.hidden_dim, bias=False) for _ in range(n_layers))
|
| 118 |
+
self.mlp_down = nn.ModuleList(nn.Linear(2 * self.hidden_dim, self.hidden_dim, bias=False) for _ in range(n_layers))
|
| 119 |
+
self.attn_norm = nn.ParameterList(nn.Parameter(torch.ones(self.hidden_dim)) for _ in range(n_layers))
|
| 120 |
+
self.mlp_norm = nn.ParameterList(nn.Parameter(torch.ones(self.hidden_dim)) for _ in range(n_layers))
|
| 121 |
+
self.final_norm = nn.Parameter(torch.ones(self.hidden_dim))
|
| 122 |
+
self.lm_head = nn.Linear(self.hidden_dim, vocab_size, bias=False)
|
| 123 |
+
|
| 124 |
+
self.to(device)
|
| 125 |
+
self.eval()
|
| 126 |
+
|
    @torch.no_grad()
    def _memory_attention(self, x: torch.Tensor, layer: int, keep_mask: torch.Tensor | None) -> torch.Tensor:
        """Cross-attention of the current token(s) over the N memory chunks.

        Args:
            x: [B, hidden] current-token hidden state(s).
            layer: index of the decoder layer whose projections are applied.
            keep_mask: [B, N] bool, True = chunk recalled/kept. ``None`` = keep all
                (the dense path used during prefill / cold-start).

        Chunks with ``keep_mask == False`` get their attention logit set to
        ``-inf`` -> softmax weight 0 -> they contribute nothing. THIS is our
        simulation of "the chunk was not recalled onto the GPU".

        Each row of ``keep_mask`` must keep at least one chunk: a row that is
        entirely ``-inf`` makes softmax return NaN.
        """
        B = x.shape[0]
        H, D = self.n_heads, self.head_dim

        q = self.wq[layer](x).view(B, H, 1, D)  # [B, H, 1, D]
        k = self.wk[layer](self.memory).view(self.n_chunks, H, D).permute(1, 0, 2)  # [H, N, D]
        v = self.wv[layer](self.memory).view(self.n_chunks, H, D).permute(1, 0, 2)  # [H, N, D]

        # [B, H, 1, N] attention logits over the N memory chunks.
        logits = torch.einsum("bhqd,hnd->bhqn", q, k) / math.sqrt(D)
        if keep_mask is not None:
            # Broadcast [B, N] -> [B, 1, 1, N] and mask the dropped chunks.
            drop = ~keep_mask.view(B, 1, 1, self.n_chunks)
            logits = logits.masked_fill(drop, float("-inf"))

        attn = torch.softmax(logits, dim=-1)  # [B, H, 1, N]
        out = torch.einsum("bhqn,hnd->bhqd", attn, v).reshape(B, self.hidden_dim)
        return self.wo[layer](out)

    @torch.no_grad()
    def step(
        self,
        token_ids: torch.Tensor,          # [B] int64
        keep_mask: torch.Tensor | None,   # [B, N] bool, or None for dense
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """One decode step. Returns (hidden [B, 4096], next-token logits [B, vocab])."""
        x = self.embed(token_ids)  # [B, hidden]
        for layer in range(self.n_layers):
            x = x + self._memory_attention(_rmsnorm(x, self.attn_norm[layer]), layer, keep_mask)
            h = _rmsnorm(x, self.mlp_norm[layer])
            x = x + self.mlp_down[layer](F.gelu(self.mlp_up[layer](h)))
        hidden = _rmsnorm(x, self.final_norm)  # [B, 4096], feeds the retriever
        return hidden, self.lm_head(hidden)

    @torch.no_grad()
    def prefill(self, prefill_ids: torch.Tensor) -> torch.Tensor:
        """Toy 'prefill': run a short prompt through DENSE memory attention.

        Returns the last token's hidden state, which seeds the very first
        retrieval cycle (the indexer needs a query hidden state to score against).
        Prefill is intentionally dense (keep_mask=None): the model sees the whole
        history before decoding begins.
        """
        if prefill_ids.shape[1] == 0:
            # Without at least one token there is no hidden state to return.
            raise ValueError("prefill_ids must contain at least one token")
        hidden = None
        for t in range(prefill_ids.shape[1]):
            hidden, _ = self.step(prefill_ids[:, t], keep_mask=None)
        return hidden  # [B, 4096]
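
The masking used by `_memory_attention` can be checked in isolation. The following is a minimal sketch that assumes only PyTorch; the helper name `masked_attention_weights` and the tensor values are illustrative and not part of this release. It shows that dropped chunks receive exactly zero attention weight, and that a row with no kept chunk produces NaN.

```python
import torch


def masked_attention_weights(logits: torch.Tensor, keep_mask: torch.Tensor) -> torch.Tensor:
    """Softmax over the last dim, with dropped positions set to -inf first."""
    return torch.softmax(logits.masked_fill(~keep_mask, float("-inf")), dim=-1)


logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
keep = torch.tensor([[True, False, True, False]])

# Dropped chunks get weight 0; the kept ones renormalise so the row sums to 1.
weights = masked_attention_weights(logits, keep)

# Caveat: if every chunk in a row is dropped, the row is all -inf and the
# softmax result is NaN, which then propagates through the decoder.
all_dropped = masked_attention_weights(logits, torch.zeros(1, 4, dtype=torch.bool))
```

The NaN case is why a keep-mask should always retain at least one chunk per sequence.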


# -----------------------------------------------------------------------------
# Retrieval helper: scores -> keep-mask (top-K or threshold)
# -----------------------------------------------------------------------------
def scores_to_keep_mask(
    scores: torch.Tensor,   # [B, N] sigmoid scores in [0, 1]
    select_mode: str,
    top_k: int,
    threshold: float,
) -> torch.Tensor:
    """Turn per-chunk retriever scores into a boolean keep-mask [B, N].

    Every row keeps at least one chunk, because a fully masked row would make
    the decoder's attention softmax return NaN.
    """
    B, N = scores.shape
    if select_mode == "topk":
        k = max(1, min(top_k, N))
        keep = torch.zeros(B, N, dtype=torch.bool, device=scores.device)
        idx = scores.topk(k, dim=-1).indices
        keep.scatter_(1, idx, True)
        return keep
    elif select_mode == "threshold":
        keep = scores > threshold
        # Always retain the best-scoring chunk so that no row ends up empty.
        keep.scatter_(1, scores.argmax(dim=-1, keepdim=True), True)
        return keep
    raise ValueError(f"unknown select_mode: {select_mode!r}")
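
To make the two selection modes concrete, here is a small worked example. It is a sketch that re-creates the selection logic inline with made-up scores rather than importing anything from this file; the values of `k` and the cut-off are arbitrary.

```python
import torch

scores = torch.tensor([[0.10, 0.90, 0.40, 0.70],
                       [0.80, 0.20, 0.60, 0.30]])

# Top-k: mark the k best-scoring chunks in every row (fixed budget per row).
k = 2
topk_mask = torch.zeros_like(scores, dtype=torch.bool)
topk_mask.scatter_(1, scores.topk(k, dim=-1).indices, True)

# Threshold: keep every chunk above the cut-off (budget varies per row).
thresh_mask = scores > 0.65

print(topk_mask.tolist())    # [[False, True, False, True], [True, False, True, False]]
print(thresh_mask.tolist())  # [[False, True, False, True], [True, False, False, False]]
```

Top-k keeps exactly two chunks per sequence, while the threshold keeps two in the first row and only one in the second, so memory traffic under the threshold mode depends on the score distribution.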


# -----------------------------------------------------------------------------
# main
# -----------------------------------------------------------------------------
def main():
    ap = argparse.ArgumentParser(
        description="Toy DeepSeek-V4-FlashMemory sparse-decode loop driven by the FlashMemory Retriever"
    )
    ap.add_argument("--ckpt", required=True, help="path to the FlashMemory DS-V4 joint checkpoint (.pt)")
    ap.add_argument("--device", default="cpu", help="cpu or cuda (default: cpu)")
    ap.add_argument("--batch", type=int, default=1, help="number of parallel decode sequences")
    ap.add_argument("--n-chunks", type=int, default=256, help="number of CSA memory chunks (the long history)")
    ap.add_argument("--steps", type=int, default=192, help="number of decode steps to generate")
    ap.add_argument("--retrieval-interval", type=int, default=64,
                    help="run the retriever every N decode steps (FlashMemory default 64)")
    ap.add_argument("--select-mode", default="topk", choices=["topk", "threshold"],
                    help="how to turn scores into a keep-mask")
    ap.add_argument("--top-k", type=int, default=64, help="chunks to recall per cycle (select-mode=topk)")
    ap.add_argument("--threshold", type=float, default=0.5, help="sigmoid keep threshold (select-mode=threshold)")
    ap.add_argument("--ensemble", default="max", choices=["max", "mean"], help="cross-layer ensemble mode")
    ap.add_argument("--max-position", type=int, default=524288, help="RoPE table length")
    ap.add_argument("--n-layers", type=int, default=2, help="toy decoder layers")
    ap.add_argument("--seed", type=int, default=0)
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    device = args.device
    B, N = args.batch, args.n_chunks

    # -- 1. Load retriever -----------------------------------------------------
    print("FlashMemory DS-V4 -- toy sparse-decode loop")
    print(f"[load] {args.ckpt}")
    retriever = FlashMemoryRetriever.from_checkpoint(
        args.ckpt, device=device, max_position=args.max_position
    )
    retriever.eval()
    print(f"[load] layers={retriever.layer_names} n_heads={retriever.n_heads} "
          f"head_dim={retriever.head_dim}")

    # -- 2. Build toy decoder + mock CSA memory --------------------------------
    decoder = ToyMemoryDecoder(n_chunks=N, n_layers=args.n_layers, device=device, seed=args.seed)
    compressed_k = make_mock_compressed_k(B, N, head_dim=retriever.head_dim,
                                          device=device, seed=args.seed)
    print(f"[init] decoder: {args.n_layers} layers, {decoder.n_heads} heads | "
          f"CSA memory: {N} chunks [{retriever.head_dim + 4}] uint8")

    # -- 3. Prefill ------------------------------------------------------------
    prefill_len = 8
    prefill_ids = torch.randint(0, 512, (B, prefill_len), device=device)
    last_hidden = decoder.prefill(prefill_ids)
    base_pos = prefill_len
    last_pos = torch.full((B,), prefill_len - 1, dtype=torch.int64, device=device)

    sel_desc = (f"top-K={args.top_k}" if args.select_mode == "topk"
                else f"sigmoid>{args.threshold}")
    print(f"\n[decode] {args.steps} steps, retriever every {args.retrieval_interval} steps "
          f"({args.select_mode} [{sel_desc}], ensemble={args.ensemble})")
    print("-" * 60)

    # -- 4. Decode loop --------------------------------------------------------
    keep_mask = None
    token = decoder.embed.weight.new_zeros(B, dtype=torch.int64)
    keep_ratios: list[float] = []
    cycle = 0

    for t in range(args.steps):
        abs_pos = base_pos + t

        if t % args.retrieval_interval == 0:
            scores = retriever.ensemble(last_hidden, compressed_k, last_pos, mode=args.ensemble)
            keep_mask = scores_to_keep_mask(scores, args.select_mode, args.top_k, args.threshold)

            n_keep = keep_mask.sum(-1)
            ratio = (n_keep.float() / N)
            keep_ratios.extend(ratio.tolist())
            w_lo = abs_pos
            w_hi = min(abs_pos + args.retrieval_interval, base_pos + args.steps) - 1

            print(f"[cycle {cycle:>2}] pos {w_lo:>5}..{w_hi:<5} | "
                  f"keep {fmt_ratio(ratio, B)} ({int(n_keep[0])}/{N}) | "
                  f"score mean={scores.mean():.4f} max={scores.max():.4f}")
            cycle += 1

        hidden, logits = decoder.step(token, keep_mask)
        token = logits.argmax(-1)
        last_hidden = hidden
        last_pos = torch.full((B,), abs_pos, dtype=torch.int64, device=device)

    # -- 5. Summary ------------------------------------------------------------
    avg_keep = sum(keep_ratios) / max(len(keep_ratios), 1)
    print("-" * 60)
    print(f"[done] {args.steps} tokens, {cycle} cycles, "
          f"avg keep/cycle: {avg_keep:.1%} -> ~{1 - avg_keep:.0%} CSA KV dropped")
    print("[note] Dropped chunks are masked to -inf in attention (= KV not recalled to GPU). "
          "Production swap engine not included in this release.")


def fmt_ratio(t: torch.Tensor, B: int) -> str:
    vals = t.tolist()
    return f"{vals[0]:.1%}" if B == 1 else "[" + ", ".join(f"{v:.1%}" for v in vals) + "]"


if __name__ == "__main__":
    main()
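
The position window printed for each retrieval cycle follows directly from the `w_lo` / `w_hi` arithmetic in the decode loop. As a sanity check, the sketch below reproduces that arithmetic in plain Python; `retrieval_windows` is a hypothetical helper written for this illustration, not a function in the release.

```python
def retrieval_windows(prefill_len: int, steps: int, interval: int) -> list[tuple[int, int]]:
    """Inclusive (first, last) absolute positions served by each retrieval cycle."""
    end = prefill_len + steps  # one past the last decoded position
    return [
        (prefill_len + t, min(prefill_len + t + interval, end) - 1)
        for t in range(0, steps, interval)
    ]


# Script defaults: 8 prefill tokens, 192 decode steps, retriever every 64 steps.
print(retrieval_windows(prefill_len=8, steps=192, interval=64))
# [(8, 71), (72, 135), (136, 199)]

# When steps is not a multiple of the interval, the last window is shorter.
print(retrieval_windows(prefill_len=8, steps=100, interval=64))
# [(8, 71), (72, 107)]
```

With the defaults the retriever therefore runs three times, and each keep-mask is reused for 64 consecutive decode steps.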
weights/flashmemory_ds_v4.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba20d264c309246f824d4471ccc637061b3b0268fe8e4eecc121474a1e5cd02a
size 509633992
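
The three lines above are a Git LFS pointer, not the weights themselves: Git stores this small text file and the actual tensor file is fetched from LFS storage. As a rough illustration, the sketch below parses the pointer text and converts the `size` field to megabytes; `parse_lfs_pointer` is a hypothetical helper, not a tool shipped with this repository.

```python
def parse_lfs_pointer(text: str) -> dict[str, str]:
    """Split each 'key value' line of a Git LFS pointer into a dict entry."""
    return dict(line.split(" ", 1) for line in text.strip().splitlines())


pointer = """\
version https://git-lfs.github.com/spec/v1
oid sha256:ba20d264c309246f824d4471ccc637061b3b0268fe8e4eecc121474a1e5cd02a
size 509633992
"""

info = parse_lfs_pointer(pointer)
size_bytes = int(info["size"])
print(f"{size_bytes / 1e6:.1f} MB, {size_bytes / 2**20:.1f} MiB")  # 509.6 MB, 486.0 MiB
```

The result agrees with the "~510 MB" figure quoted in the commit message.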