---
language: en
license: mit
tags:
- deepseek-v4
- retrieval
- kv-cache
- sparse-attention
- long-context
- flashmemory
datasets:
- ruler
- longmemeval
- longbench-v2
- mrcr
---
# FlashMemory DS-V4 Retriever
A lightweight retriever that sparsifies the **DeepSeek-V4 CSA KV-cache**. Given a
decode-token hidden state, it predicts which compressed-K chunks the next
~64 tokens will attend to, so only those chunks stay on GPU and the rest are offloaded.
In downstream evaluation it matches or beats the full-attention baseline on
reasoning-heavy long-context tasks (**RULER, LongMemEval, LongBench V2**)
while reducing KV-cache usage by **~85–90%**. Precise needle-retrieval tasks
require an additional threshold-fallback mechanism (not in this release).
## Quick start
```bash
pip install torch safetensors
python demo.py --ckpt weights/flashmemory_ds_v4.safetensors
```
## Usage
```python
from retriever import FlashMemoryRetriever

model = FlashMemoryRetriever.from_checkpoint(
    "weights/flashmemory_ds_v4.safetensors", device="cuda"
)

# hidden:       [B, 4096]    decode hidden state
# compressed_k: [B, N, 132]  uint8 CSA keys
# positions:    [B]          int64 token positions
scores = model.ensemble(hidden, compressed_k, positions, mode="max")    # [B, N]
keep = model.select_topk(hidden, compressed_k, positions, top_k=512)    # boolean mask
```
**`compressed_k` format:** each chunk is 132 bytes: 128 bytes of `float8_e4m3` values plus a 4-byte `float32` scale. See `make_mock_compressed_k()` in `demo.py`.
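
As a rough illustration of this layout, the sketch below decodes one 132-byte chunk in plain Python. It assumes that the 128 FP8 bytes come first and the little-endian `float32` scale last, that `float8_e4m3` means the common "fn" variant (exponent bias 7, no infinities), and that dequantization is value times scale. None of these details is confirmed by this card, and `decode_e4m3` and `dequantize_chunk` are hypothetical helper names, not part of `retriever.py`.

```python
import struct


def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 byte (assumed 'fn' variant: bias 7, no infinities)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0x0F
    man = byte & 0x07
    if exp == 0x0F and man == 0x07:
        return float("nan")  # the only NaN pattern in E4M3fn
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def dequantize_chunk(chunk: bytes) -> list[float]:
    """Turn one 132-byte chunk into 128 floats.

    Assumed layout: bytes 0..127 are FP8 values, bytes 128..131 are a
    little-endian float32 scale that multiplies every value.
    """
    if len(chunk) != 132:
        raise ValueError(f"expected 132 bytes, got {len(chunk)}")
    (scale,) = struct.unpack("<f", chunk[128:132])
    return [decode_e4m3(b) * scale for b in chunk[:128]]
```

For real inputs, `make_mock_compressed_k()` in `demo.py` remains the authoritative reference for the byte order.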
## Architecture
3-layer joint model (`l10`, `l12`, `l20`), 128 heads, 2048 LoRA rank. Per-layer
sigmoid scores are ensembled (`max` or `mean`) per chunk.
```
hidden [B,4096] → q-proj → RoPE(YaRN) → Hadamard → q [B,128,128]
                → weights_proj → fused_w [B,128]
compressed_k → FP8 dequant → k [B,N,128]
score = sigmoid( Σ( relu(k @ qᵀ) · fused_w ) ) ∈ [0,1]
```
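
The scoring line above can be read as follows for a single batch element and a single layer. This is a dependency-free sketch, not the code in `retriever.py`: the function names are hypothetical, `q` and `k` are assumed to be already projected and dequantized, and the sum is assumed to run over heads.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def chunk_scores(q, k, fused_w):
    """Score every chunk for one layer.

    q:       [H][D] query vector per head
    k:       [N][D] dequantized key per chunk
    fused_w: [H]    per-head weight
    Returns N scores in [0, 1].
    """
    scores = []
    for key in k:
        total = 0.0
        for head_q, w in zip(q, fused_w):
            dot = sum(a * b for a, b in zip(key, head_q))
            total += max(dot, 0.0) * w  # relu(k @ q^T) * fused_w
        scores.append(sigmoid(total))
    return scores


def ensemble(per_layer, mode="max"):
    """Combine per-layer score lists (each of length N) chunk by chunk."""
    if mode == "max":
        return [max(col) for col in zip(*per_layer)]
    if mode == "mean":
        return [sum(col) / len(col) for col in zip(*per_layer)]
    raise ValueError(f"unknown mode: {mode!r}")
```

With `mode="max"` a chunk is kept if any one layer scores it highly, which favors recall; `mode="mean"` requires broader agreement across layers.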
## Toy inference reference
`toy_flashmemory_inference.py` illustrates how the retriever drives memory
recall during decode: every 64 steps it re-scores all chunks, and unselected
ones are masked from attention (equivalent to "not recalled to GPU").
```bash
python toy_flashmemory_inference.py --ckpt weights/flashmemory_ds_v4.safetensors
```
> The decoder is a few toy layers with random weights; it is **not** a real
> DeepSeek-V4. The retriever, scoring math, and decode-time control flow are real.
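
The control flow described above (re-score every 64 steps, mask whatever is not selected) might look like this in outline. `topk_mask`, `decode_with_recall`, and `score_fn` are hypothetical names used only for illustration; the actual loop lives in `toy_flashmemory_inference.py` and may differ, for example in how ties are broken.

```python
def topk_mask(scores, top_k):
    """Boolean keep-mask marking the top_k highest-scoring chunks."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = [False] * len(scores)
    for i in order[:top_k]:
        keep[i] = True
    return keep


def decode_with_recall(num_steps, score_fn, top_k, interval=64):
    """Return the keep-mask in force at each decode step.

    score_fn(step) stands in for the retriever: it returns one score per
    chunk. The mask is refreshed only every `interval` steps and reused
    in between; chunks with keep == False would be masked from attention.
    """
    masks = []
    keep = None
    for step in range(num_steps):
        if step % interval == 0:
            keep = topk_mask(score_fn(step), top_k)
        masks.append(keep)
    return masks
```

Refreshing the mask once per 64 steps matches the retriever's prediction horizon of roughly 64 tokens, so scoring cost is amortized over the whole window.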
## Files
| File | Purpose |
|------|---------|
| `retriever.py` | `FlashMemoryRetriever` model (torch-only, self-contained) |
| `demo.py` | minimal demo with mock inputs |
| `toy_flashmemory_inference.py` | toy sparse-decode loop |
| `weights/flashmemory_ds_v4.safetensors` | trained weights (~510 MB) |
## License
MIT