---
language: en
license: mit
tags:
- deepseek-v4
- retrieval
- kv-cache
- sparse-attention
- long-context
- flashmemory
datasets:
- ruler
- longmemeval
- longbench-v2
- mrcr
---
# FlashMemory DS-V4 Retriever
A lightweight retriever that sparsifies the **DeepSeek-V4 CSA KV-cache**. Given a
decode-token hidden state, it predicts which compressed-K chunks the next
~64 tokens will attend to, so only those chunks are kept on GPU and the rest are offloaded.
In downstream evaluation it matches or beats the full-attention baseline on
reasoning-heavy long-context tasks (**RULER, LongMemEval, LongBench V2**)
while reducing KV-cache usage by **~85–90%**. Precise needle-retrieval tasks
require an additional threshold-fallback mechanism (not included in this release).
## Quick start
```bash
pip install torch safetensors
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
```
## Usage
```python
from retriever import FlashMemoryRetriever
model = FlashMemoryRetriever.from_checkpoint(
    "weights/flashmemory_ds_v4.safetensors", device="cuda"
)
# hidden: [B, 4096] decode hidden state
# compressed_k: [B, N, 132] uint8 CSA keys
# positions: [B] int64 token positions
scores = model.ensemble(hidden, compressed_k, positions, mode="max") # [B, N]
keep = model.select_topk(hidden, compressed_k, positions, top_k=512) # boolean mask
```
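`select_topk` returns a boolean mask over the `N` chunks of each batch row. Assuming it keeps the `top_k` highest ensemble scores (this card does not specify tie-breaking or other details), the selection for a single row looks roughly like the pure-Python sketch below. `topk_mask` is a hypothetical helper, not part of `retriever.py`.

```python
import heapq
from typing import List


def topk_mask(scores: List[float], top_k: int) -> List[bool]:
    """Hypothetical helper: mark the top_k highest-scoring chunks as kept.

    Ties go to the lower chunk index; if top_k >= len(scores), all are kept.
    """
    k = max(0, min(top_k, len(scores)))
    # Indices of the k largest scores; key (score, -index) breaks ties toward lower index.
    keep_idx = heapq.nlargest(k, range(len(scores)), key=lambda i: (scores[i], -i))
    mask = [False] * len(scores)
    for i in keep_idx:
        mask[i] = True
    return mask


scores = [0.10, 0.92, 0.35, 0.80, 0.05]  # made-up ensemble scores for 5 chunks
mask = topk_mask(scores, top_k=2)
print(mask)  # [False, True, False, True, False]
print(f"kept {sum(mask) / len(mask):.0%} of chunks")  # kept 40% of chunks
```

With hypothetical sizes of `top_k=512` and 4,096 chunks in a row, such a mask would keep 512 / 4096 = 12.5% of the chunks resident.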
**`compressed_k` format:** each 132-byte chunk holds 128 bytes of `float8_e4m3` values plus a 4-byte `float32` scale. See `make_mock_compressed_k()` in `demo.py`.
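To make the byte layout concrete, the sketch below decodes one 132-byte chunk in plain Python. It assumes the 128 value bytes come first, that the scale is a little-endian `float32` multiplying every value, and that `float8_e4m3` means the common `e4m3fn` encoding (exponent bias 7, no infinities, `0x7F`/`0xFF` as NaN, as in `torch.float8_e4m3fn`). `decode_e4m3fn` and `dequant_chunk` are hypothetical names.

```python
import math
import struct
from typing import List

CHUNK_BYTES = 132  # 128 FP8 values + 4-byte float32 scale
HEAD_DIM = 128


def decode_e4m3fn(byte: int) -> float:
    """Decode one FP8 E4M3 byte (fn variant: no infinities) to a float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0x0F
    mant = byte & 0x07
    if exp == 0x0F and mant == 0x07:
        return math.nan  # S.1111.111 is the only NaN pattern
    if exp == 0:
        # Subnormal: 2^(1 - bias) * (mant / 8), with bias = 7
        return sign * math.ldexp(mant / 8.0, -6)
    return sign * math.ldexp(1.0 + mant / 8.0, exp - 7)


def dequant_chunk(chunk: bytes) -> List[float]:
    """Dequantize one chunk into HEAD_DIM floats.

    Assumed layout: bytes[0:128] are FP8 values, bytes[128:132] are a
    little-endian float32 scale that multiplies every value.
    """
    if len(chunk) != CHUNK_BYTES:
        raise ValueError(f"expected {CHUNK_BYTES} bytes, got {len(chunk)}")
    (scale,) = struct.unpack("<f", chunk[HEAD_DIM:])
    return [decode_e4m3fn(b) * scale for b in chunk[:HEAD_DIM]]


# 0x38 encodes 1.0 (exp=7, mant=0) and 0x40 encodes 2.0; the scale is 0.5.
chunk = bytes([0x38, 0x40] + [0x00] * 126) + struct.pack("<f", 0.5)
print(dequant_chunk(chunk)[:3])  # [0.5, 1.0, 0.0]
```

In real code you would avoid per-byte Python loops, for example by viewing the `uint8` bytes as `torch.float8_e4m3fn`; the scalar version only spells out the arithmetic.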
## Architecture
3-layer joint model (`l10`, `l12`, `l20`) with 128 heads and LoRA rank 2048. Per-layer
sigmoid scores are combined per chunk with `max` or `mean`.
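The per-chunk ensemble is a reduction across layers. A minimal sketch, assuming each layer has already produced one sigmoid score per chunk; the function name `ensemble_scores` and the scores are hypothetical:

```python
from typing import List, Sequence


def ensemble_scores(per_layer: Sequence[Sequence[float]], mode: str = "max") -> List[float]:
    """Combine per-layer chunk scores (each in [0, 1]) into one score per chunk.

    per_layer[l][n] is layer l's sigmoid score for chunk n.
    """
    if not per_layer:
        raise ValueError("need at least one layer")
    n_chunks = len(per_layer[0])
    if any(len(row) != n_chunks for row in per_layer):
        raise ValueError("all layers must score the same chunks")
    columns = list(zip(*per_layer))  # one tuple of layer scores per chunk
    if mode == "max":
        return [max(col) for col in columns]
    if mode == "mean":
        return [sum(col) / len(col) for col in columns]
    raise ValueError(f"unknown mode: {mode!r}")


# Made-up scores for 3 layers (think l10, l12, l20) over 4 chunks.
layers = [
    [0.9, 0.1, 0.4, 0.2],
    [0.3, 0.2, 0.8, 0.2],
    [0.6, 0.1, 0.2, 0.5],
]
print(ensemble_scores(layers, "max"))  # [0.9, 0.2, 0.8, 0.5]
```

`max` keeps a chunk whenever any single layer rates it highly, which favors recall; `mean` ranks a chunk high only when the layers broadly agree.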
```
hidden [B,4096] → q-proj → RoPE(YaRN) → Hadamard → q [B,128,128]
hidden [B,4096] → weights_proj → fused_w [B,128]
compressed_k → FP8 dequant → k [B,N,128]
score = sigmoid( Σ( relu(k @ qᵀ) · fused_w ) ) → [0,1]
```
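A scalar, single-layer sketch of the `score` line above for one chunk. It takes `q`, `k`, and `fused_w` as already computed (the q-proj, RoPE, Hadamard, and dequant steps are omitted), reads `Σ` as a sum over heads since `fused_w` has one entry per head, and does not show any extra scaling factors `retriever.py` may apply. `chunk_score` and `stable_sigmoid` are hypothetical names.

```python
import math
from typing import Sequence


def stable_sigmoid(x: float) -> float:
    """Sigmoid that avoids overflow in math.exp for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def chunk_score(
    q: Sequence[Sequence[float]],  # [H][D] per-head queries (after q-proj, RoPE, Hadamard)
    k: Sequence[float],            # [D] dequantized key of one chunk
    fused_w: Sequence[float],      # [H] per-head weights from weights_proj
) -> float:
    """One layer, one chunk: sigmoid( sum_h relu(k . q_h) * fused_w[h] )."""
    if len(q) != len(fused_w):
        raise ValueError("q and fused_w must have one entry per head")
    logit = 0.0
    for q_h, w_h in zip(q, fused_w):
        if len(q_h) != len(k):
            raise ValueError("query and key dims differ")
        dot = sum(a * b for a, b in zip(k, q_h))
        logit += max(dot, 0.0) * w_h  # ReLU per head, then weight the head
    return stable_sigmoid(logit)


q = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # 2 toy heads, dim 3
k = [2.0, -1.0, 0.5]
fused_w = [0.5, 3.0]
print(round(chunk_score(q, k, fused_w), 4))  # 0.7311 (the second head is clipped by ReLU)
```

In the real model this runs over 128 heads of dimension 128 and is batched over all `N` chunks at once; the loop form just makes the order of operations (dot product, ReLU, head weighting, sum, sigmoid) explicit.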
## Toy inference reference
`toy_flashmemory_inference.py` illustrates how the retriever drives memory
recall during decode: every 64 steps it re-scores all chunks, and unselected
ones are masked from attention (equivalent to "not recalled to GPU").
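The control flow described above can be sketched in plain Python with the model replaced by a scoring callback. `run_decode`, `score_fn`, and `fake_scores` are hypothetical names, not functions from `toy_flashmemory_inference.py`; only the 64-step re-scoring cadence comes from the description, and top-k selection is assumed, as with `select_topk`.

```python
from typing import Callable, List

RESCORE_EVERY = 64  # decode steps between re-scoring


def run_decode(
    num_steps: int,
    num_chunks: int,
    top_k: int,
    score_fn: Callable[[int, int], List[float]],
) -> List[List[int]]:
    """Toy control flow: re-select resident chunks every RESCORE_EVERY steps.

    score_fn(step, num_chunks) is a hypothetical stand-in for the retriever and
    returns one score per chunk. Returns, for each decode step, the sorted chunk
    indices that attention may see at that step.
    """
    visible: List[int] = []
    history: List[List[int]] = []
    for step in range(num_steps):
        if step % RESCORE_EVERY == 0:
            scores = score_fn(step, num_chunks)
            ranked = sorted(range(num_chunks), key=lambda i: (-scores[i], i))
            visible = sorted(ranked[:top_k])  # "recalled to GPU"
        # Chunks outside `visible` would be masked out of attention at this step.
        history.append(visible)
    return history


def fake_scores(step: int, n: int) -> List[float]:
    # Hypothetical stand-in: favor chunks near a focus that jumps 3 chunks per window.
    focus = (step // RESCORE_EVERY * 3) % n
    return [1.0 / (1 + abs(i - focus)) for i in range(n)]


hist = run_decode(num_steps=130, num_chunks=8, top_k=3, score_fn=fake_scores)
print(hist[0], hist[63], hist[64], hist[128])  # [0, 1, 2] [0, 1, 2] [2, 3, 4] [5, 6, 7]
```

Between re-scoring points the selection is frozen, so the retriever cost is paid once per 64 decode steps rather than on every token.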
```bash
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
```
> The decoder is a few toy layers with random weights; it is **not** a real
> DeepSeek-V4. The retriever, scoring math, and decode-time control flow are real.
## Files
| File | Purpose |
|------|---------|
| `retriever.py` | `FlashMemoryRetriever` model (torch-only, self-contained) |
| `demo.py` | minimal demo with mock inputs |
| `toy_flashmemory_inference.py` | toy sparse-decode loop |
| `weights/flashmemory_ds_v4.safetensors` | trained weights (~510 MB) |
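As a plausibility check on the ~510 MB file size, here is a back-of-envelope parameter count. The projection shapes are a guess assembled from numbers in this card (4096-dim hidden state, rank-2048 q-proj, 128 heads of dim 128, three layers) plus assumed fp32 storage with no biases or norms; they are not read from the checkpoint. The result is consistent with, but does not confirm, the actual layout.

```python
# Back-of-envelope parameter count under ASSUMED shapes (not read from the checkpoint).
HIDDEN = 4096     # decode hidden size (from the Usage section)
LORA_RANK = 2048  # "LoRA rank 2048"
HEADS = 128       # "128 heads"
HEAD_DIM = 128    # q is [B, 128, 128]
LAYERS = 3        # l10, l12, l20

# Assumption: q-proj = down-proj (HIDDEN -> LORA_RANK) + up-proj (LORA_RANK -> HEADS * HEAD_DIM);
# weights_proj = HIDDEN -> HEADS; no biases or other tensors stored.
q_proj = HIDDEN * LORA_RANK + LORA_RANK * HEADS * HEAD_DIM
weights_proj = HIDDEN * HEADS
params = LAYERS * (q_proj + weights_proj)

bytes_fp32 = params * 4  # assumed fp32 storage
print(f"{params:,} params -> {bytes_fp32 / 1e6:.1f} MB at fp32")  # 127,401,984 params -> 509.6 MB at fp32
```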
## License
MIT