GENOMA LABS / research
initial release: H2O KV cache eviction for DeepseekV3 / MLA architectures
# How It Works: Architectural Notes
This document explains how the eviction patch hooks into a `DeepseekV3Attention` layer and where the design choices live. Read this before modifying the patch or porting it to a different attention class.
## The patch surface
`install_kv_eviction(model, ...)` walks `model.modules()` and finds every `DeepseekV3Attention` instance. For each one it:
1. **Stashes the original `forward` method** as an attribute on the module (so `remove_kv_eviction` can restore it later).
2. **Creates an `_EvictionState` object** holding `budget`, `n_sink`, `n_recent`, `evict_every`, and a per-layer accumulated-score tensor `score`.
3. **Replaces the layer's `forward`** with a closure that wraps the original. The closure runs the original forward, captures the attention probabilities, accumulates them into `state.score`, then calls `_maybe_evict` on the cache.
The original forward path is not modified. We sit *outside* the math, observing the attention probabilities and managing the cache as a side effect.
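The wrap-and-restore pattern described above can be sketched without a real model. This is a minimal illustration, not the patch itself: `ToyAttention`, `ToyModel`, `install_wrapper`, `remove_wrapper`, and the `_orig_forward` attribute name are all hypothetical stand-ins, and the `on_forward` callback takes the place of score accumulation and `_maybe_evict`.

```python
class ToyAttention:
    """Hypothetical stand-in for an attention layer such as DeepseekV3Attention."""

    def forward(self, x):
        return x * 2


class ToyModel:
    """Hypothetical stand-in for a model; mimics the modules() iterator."""

    def __init__(self, n_layers):
        self.layers = [ToyAttention() for _ in range(n_layers)]

    def modules(self):
        return iter(self.layers)


def install_wrapper(model, target_cls, on_forward):
    """Wrap forward on every target_cls instance; return how many were patched."""
    patched = 0
    for module in model.modules():
        if not isinstance(module, target_cls) or hasattr(module, "_orig_forward"):
            continue                           # wrong class, or already patched
        original = module.forward              # bound method of the unpatched layer
        module._orig_forward = original        # stash so it can be restored later

        # Bind `original` as a default argument: a plain closure created in a
        # loop would see only the last layer's forward (late binding).
        def wrapped(*args, _original=original, **kwargs):
            out = _original(*args, **kwargs)   # the original math runs untouched
            on_forward(out)                    # side effect only
            return out

        module.forward = wrapped               # instance attribute shadows the class method
        patched += 1
    return patched


def remove_wrapper(model):
    """Undo install_wrapper on every patched module."""
    for module in model.modules():
        if hasattr(module, "_orig_forward"):
            del module.forward                 # drop the instance-level override
            del module._orig_forward
```

Calling `install_wrapper(ToyModel(3), ToyAttention, print)` patches three layers. A second call patches none, because the stashed attribute doubles as an "already installed" marker.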
## What gets evicted
Eviction operates on the layer's KV cache. With `transformers.cache_utils.DynamicCache` (the default for HuggingFace generation), the cache holds a `key_cache[layer_idx]` and a `value_cache[layer_idx]` tensor per layer, each of shape `[batch, heads, seq, dim]` (under MLA the last dimension differs between keys and values). We slice both along the `seq` dimension with the same indices, keeping:
- **Sinks:** indices `[0, n_sink)` always.
- **Recent:** indices `[seq - n_recent, seq)` always.
- **Heavy hitters:** the top `budget` indices by accumulated attention mass, drawn from the middle range `[n_sink, seq - n_recent)`.
The middle range is where eviction actually does work; the sink and recent ranges are non-evictable.
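The selection rule above can be sketched for a single sequence as follows. `select_keep_indices` is a hypothetical helper, not a function from the patch. It uses plain Python lists where the patch works on tensors, and it assumes one score per cached position.

```python
import heapq


def select_keep_indices(scores, n_sink, n_recent, budget):
    """Return the sorted cache positions to keep for one sequence."""
    seq = len(scores)
    if seq <= n_sink + budget + n_recent:
        return list(range(seq))                  # nothing to evict yet
    sinks = range(0, n_sink)
    recent = range(seq - n_recent, seq)
    middle = range(n_sink, seq - n_recent)       # the only evictable span
    # Top `budget` positions in the middle range by accumulated score.
    heavy = heapq.nlargest(budget, middle, key=lambda i: scores[i])
    return sorted([*sinks, *heavy, *recent])     # keep original token order
```

The returned positions would then be used to index keys, values, and the score tensor along `seq`, so all three stay aligned after eviction.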
## Score accumulation
Per-token attention mass is accumulated as:
```
state.score[batch, token] += sum_over_heads(attn_probs[batch, head, *, token])
```
That is, for each token in the cache, sum the attention probability it received from every query at this step, across all heads. Accumulation runs on every step. Over time, tokens that many heads attend to repeatedly build up large scores, while tokens attended to only once or twice fall behind in the ranking.
**Cross-head, not per-head.** This is a defensible default but it has a downside: heads that specialize (e.g., one head dedicated to attention sinks, another to retrieval) are pooled into a single score. Per-head policies sometimes perform better in production; this implementation does not currently expose that knob. PRs welcome.
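To make the reduction concrete, here is a minimal sketch for one batch element, using nested lists in place of tensors. `accumulate_scores` is a hypothetical name. The actual patch operates on tensors, where the same reduction over a `[batch, heads, q_len, k_len]` tensor would be a sum over the `heads` and `q_len` dimensions.

```python
def accumulate_scores(score, attn_probs):
    """Add one step's attention mass to a per-token score list, in place.

    score:      list of floats, one per cached token (k_len entries)
    attn_probs: nested lists shaped [heads][q_len][k_len] for one batch element
    """
    for head in attn_probs:
        for query_row in head:
            if len(query_row) != len(score):
                raise ValueError("k_len mismatch between scores and attention")
            for token, prob in enumerate(query_row):
                score[token] += prob     # cross-head: every head feeds one score
    return score
```

Because each query row is a probability distribution, every step adds a total mass of `heads * q_len` to the scores. That gives a quick sanity check when debugging accumulation.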
## When eviction triggers
`evict_every=N` controls how often `_maybe_evict` actually runs. Default is 1 (every step). Larger values reduce per-step overhead but increase peak cache size between evictions. For single-token decode steps, with the cache measured after each patched forward returns:
```
peak_cache_size <= n_sink + budget + n_recent + (evict_every - 1)
```
For most workloads `evict_every=1` is fine. If the eviction step shows up in profiling as a measurable cost (rare, since `argpartition` on a few thousand floats is fast), increase to 8 or 16.
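The bound can be checked with a small simulation. This is a sketch under stated assumptions: one token is appended per decode step, eviction fires on every `evict_every`-th step, and the cache length is sampled after each patched forward returns. `simulate_peak` is a hypothetical helper.

```python
def simulate_peak(n_sink, budget, n_recent, evict_every, steps):
    """Return the largest post-step cache length seen over `steps` decode steps."""
    cap = n_sink + budget + n_recent
    length = cap                      # start right after an eviction
    peak = length
    for step in range(1, steps + 1):
        length += 1                   # the forward pass appends one token
        if step % evict_every == 0 and length > cap:
            length = cap              # eviction trims back to the kept set
        peak = max(peak, length)      # sampled after the step completes
    return peak
```

Under these assumptions the simulated peak equals `n_sink + budget + n_recent + (evict_every - 1)`. Inside the forward pass on an eviction step the cache briefly holds one more token, since eviction runs only after the original forward has appended it.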
## Edge cases
A few cases the code handles explicitly:
- **Cache below threshold.** If `len(cache) <= n_sink + budget + n_recent`, `_maybe_evict` is a no-op. No eviction happens until the cache grows past that total.
- **First call.** The accumulated score tensor is initialized lazily on the first forward pass once we know batch size and current cache length.
- **Cache reset between generations.** `reset_eviction_scores(model)` zeroes all accumulated scores. Call this between independent generations; otherwise the previous generation's heavy hitters bias the next one.
- **Model on multiple devices.** The score tensor lives on the same device as the cache. With `device_map="auto"` model sharding, each layer's score lives on its layer's device; nothing special to do.
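The threshold check, lazy initialization, and reset behaviour can be sketched with a list-based state object. `ToyEvictionState` is hypothetical and is not the patch's `_EvictionState`. Growing the scores with zeros as the cache grows, and dropping them on reset rather than zeroing in place, are assumptions made here so that the sketch works across generations of different lengths.

```python
class ToyEvictionState:
    """Hypothetical, list-based stand-in for the per-layer eviction state."""

    def __init__(self, budget, n_sink, n_recent):
        self.budget, self.n_sink, self.n_recent = budget, n_sink, n_recent
        self.score = None                        # allocated lazily

    def ensure_score(self, batch, cache_len):
        """Allocate on first use, then pad with zeros as the cache grows."""
        if self.score is None:
            self.score = [[0.0] * cache_len for _ in range(batch)]
        for row in self.score:
            # A non-positive count yields an empty list, so this is a no-op
            # when the row already covers the cache.
            row.extend([0.0] * (cache_len - len(row)))
        return self.score

    def should_evict(self, cache_len):
        """True only once the cache exceeds the total it is allowed to keep."""
        return cache_len > self.n_sink + self.budget + self.n_recent

    def reset(self):
        """Forget accumulated scores; the next forward re-allocates (assumption)."""
        self.score = None
```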
## Failure modes to watch for
If eviction breaks, you'll typically see one of these:
- **Repetitive garbage output** -> sinks were evicted. Verify `n_sink >= 4` and that position 0 along the `seq` dimension of `key_cache[layer_idx]` is preserved across eviction calls.
- **Topical drift / stale context** -> recent window too small. Increase `n_recent`.
- **Sudden quality cliff at long context** -> heavy-hitter score accumulation has saturated or is normalized incorrectly. Check that `state.score` is being updated every step (not just on eviction-trigger steps).
- **Memory growing despite eviction** -> the cache class is not `DynamicCache` and the patch's slicing assumptions don't hold. Print `type(past_key_value)` from inside the patched forward to verify.
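Several of these symptoms trace back to a bad keep set, so checking the positions an eviction call retains can help while debugging. The following is a sketch of such a check. `check_keep_indices` is a hypothetical helper, and it assumes you can log the kept positions for one sequence from inside the patched forward.

```python
def check_keep_indices(keep, seq, n_sink, n_recent, budget):
    """Return a list of human-readable problems; empty means the set looks sane."""
    problems = []
    kept = set(keep)
    if len(kept) != len(keep):
        problems.append("duplicate positions in keep set")
    if any(i < 0 or i >= seq for i in kept):
        problems.append("position outside [0, seq)")
    missing_sinks = [i for i in range(min(n_sink, seq)) if i not in kept]
    if missing_sinks:
        problems.append(f"sink positions evicted: {missing_sinks}")
    missing_recent = [i for i in range(max(seq - n_recent, 0), seq) if i not in kept]
    if missing_recent:
        problems.append(f"recent positions evicted: {missing_recent}")
    if len(kept) > n_sink + budget + n_recent:
        problems.append("keep set exceeds n_sink + budget + n_recent")
    if list(keep) != sorted(keep):
        problems.append("positions out of order; cache order would be scrambled")
    return problems
```

An empty result rules out evicted sinks, a truncated recent window, and an oversized keep set. It says nothing about whether the scores themselves are accumulated correctly.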
## Porting to a different attention class
To adapt this patch for non-MLA attention (Llama, Qwen, Mistral, Gemma standard MHA / GQA):
1. Find the attention class in the relevant `transformers/models/<arch>/modeling_<arch>.py`.
2. Identify how it stores K/V in the cache. Most use `DynamicCache` with `[batch, heads, seq, head_dim]` (under GQA, `heads` here is the number of key/value heads, not query heads). MLA stores expanded K/V at `[batch, heads, seq, qk_dim]` and `[batch, heads, seq, v_dim]` separately because `qk_dim != v_dim`. For uniform-dim attention, the slicing in `_maybe_evict` simplifies.
3. Replace the `DeepseekV3Attention` class lookup in `install_kv_eviction` with the relevant class.
4. Verify the attention probability tensor shape returned by the original forward; the score-accumulation code assumes `[batch, heads, q_len, k_len]`.
The H2O recipe itself does not need to change; only the cache management code does.
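When porting, a shape check run once on the first patched forward can catch a mismatched attention class early. The sketch below works on plain shape tuples (for example `tuple(t.shape)`). `validate_shapes` is a hypothetical helper, and it assumes shapes are read after the original forward has appended the new token to the cache.

```python
def validate_shapes(attn_shape, key_shape, value_shape):
    """Raise ValueError if shapes break the [batch, heads, q_len, k_len] assumption."""
    if len(attn_shape) != 4:
        raise ValueError(f"expected 4-D attention probabilities, got {attn_shape}")
    if len(key_shape) != 4 or len(value_shape) != 4:
        raise ValueError("expected 4-D key and value cache tensors")
    batch, heads, _q_len, k_len = attn_shape
    if key_shape[0] != batch or value_shape[0] != batch:
        raise ValueError("batch size differs between attention and cache")
    if key_shape[2] != k_len or value_shape[2] != k_len:
        raise ValueError("k_len does not match the cache's seq dimension")
    # Under GQA the cache holds fewer heads than the attention probabilities.
    if key_shape[1] != value_shape[1] or heads % key_shape[1] != 0:
        raise ValueError("query heads must be a multiple of the cache's head count")
    # The last dims may differ (MLA: qk_dim != v_dim), so they are not compared.
```

Because scores are summed across heads, a head-count difference between the attention probabilities and the cache does not affect which positions are kept. Only the `seq` alignment matters for slicing.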
## Testing methodology
The smoke test in `src/kv_eviction_mla.py` exercises the eviction-decision logic on a `_FakeCache` that mirrors `DynamicCache`'s structure but skips the model. Real-model validation requires a GPU and a checkpoint. We recommend:
1. Run the smoke test (no GPU required) to verify the patch logic.
2. Patch a small DeepSeek/Kimi model (e.g., a 1B or 7B variant) and run a perplexity-sweep notebook on a public benchmark prompt set.
3. Run RULER 128K NIAH at your target budget. Compare to full-cache baseline.
4. Stress-test on your actual workload distribution.
A canonical RULER benchmark run with this code is on the project roadmap.