
How It Works: Architectural Notes

This document explains how the eviction patch hooks into a DeepseekV3Attention layer and where the design choices live. Read this before modifying the patch or porting it to a different attention class.

The patch surface

install_kv_eviction(model, ...) walks model.modules() and finds every DeepseekV3Attention instance. For each one it:

  1. Stashes the original forward method as an attribute on the module (so remove_kv_eviction can restore it later).
  2. Creates an _EvictionState object holding budget, n_sink, n_recent, evict_every, and the layer's accumulated-score tensor, score.
  3. Replaces the layer's forward with a closure that wraps the original. The closure runs the original forward, captures the attention probabilities, accumulates them into state.score, then calls _maybe_evict on the cache.

The original forward path is not modified. We sit outside the math, observing the attention probabilities and managing the cache as a side effect.
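The stash/wrap/restore pattern in the list above can be sketched without torch. FakeAttention, EvictionState, install, and remove below are hypothetical stand-ins, not the classes and functions in src/kv_eviction_mla.py. The closure only marks where score accumulation and _maybe_evict would run. On a real torch nn.Module the same instance-attribute override takes effect, because calling the module dispatches to self.forward.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class EvictionState:
    """Hypothetical per-layer state (mirrors the fields listed above)."""
    budget: int
    n_sink: int
    n_recent: int
    evict_every: int
    score: List[List[float]] = field(default_factory=list)  # [batch][token]


class FakeAttention:
    """Hypothetical stand-in for DeepseekV3Attention: returns (output, attn_probs)."""

    def forward(self, x):
        return x, [[1.0]]


def install(modules: List[Any], target_cls: type, **cfg: int) -> int:
    """Patch every instance of target_cls; return how many were patched."""
    patched = 0
    for mod in modules:
        if not isinstance(mod, target_cls) or hasattr(mod, "_orig_forward"):
            continue  # wrong class, or already patched
        mod._orig_forward = mod.forward            # 1. stash for remove()
        mod._evict_state = EvictionState(**cfg)    # 2. per-layer state

        # 3. closure around the original; _mod=mod binds this loop's module
        def wrapped(*args, _mod=mod, **kwargs):
            out, probs = _mod._orig_forward(*args, **kwargs)
            # Score accumulation and the _maybe_evict call would go here.
            return out, probs

        mod.forward = wrapped
        patched += 1
    return patched


def remove(modules: List[Any]) -> None:
    """Restore the stashed forward and drop the per-layer state."""
    for mod in modules:
        if hasattr(mod, "_orig_forward"):
            mod.forward = mod._orig_forward
            del mod._orig_forward, mod._evict_state
```

Skipping modules that already carry _orig_forward makes a second install call a no-op instead of wrapping the wrapper.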

What gets evicted

Eviction operates on the layer's KV cache. With transformers.cache_utils.DynamicCache (the default for HuggingFace generation), each layer has a key_cache[layer_idx] and value_cache[layer_idx] tensor of shape [batch, heads, seq, dim]. We gather along the seq dimension (the kept indices need not be contiguous), keeping:

  • Sinks: indices [0, n_sink) always.
  • Recent: indices [seq - n_recent, seq) always.
  • Heavy hitters: the top budget indices by accumulated attention mass, drawn from the middle range [n_sink, seq - n_recent).

The middle range is where eviction actually does work; the sink and recent ranges are non-evictable.
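A minimal, framework-free sketch of the keep-set selection for one batch row, assuming score is a flat list holding the accumulated mass of each cached token. select_keep_indices is a hypothetical name; the project's _maybe_evict may use a different top-k routine.

```python
import heapq
from typing import List, Sequence


def select_keep_indices(score: Sequence[float], n_sink: int, budget: int,
                        n_recent: int) -> List[int]:
    """Return the seq indices to keep for one batch row, in ascending order.

    score[t] is the accumulated attention mass of cached token t.
    """
    seq = len(score)
    if seq <= n_sink + budget + n_recent:
        return list(range(seq))                    # below threshold: no-op
    sinks = range(n_sink)                          # always kept
    recent = range(seq - n_recent, seq)            # always kept
    middle = range(n_sink, seq - n_recent)         # the only evictable range
    heavy = heapq.nlargest(budget, middle, key=lambda t: score[t])
    # Sorting keeps the surviving K/V entries in their original seq order.
    return sorted([*sinks, *heavy, *recent])


print(select_keep_indices([0, 0, 5, 1, 9, 2, 7, 0, 0], n_sink=2, budget=2, n_recent=2))
# [0, 1, 4, 6, 7, 8]
```

In torch, a keep list shared by the whole batch can be applied with k.index_select(2, idx) on each K/V tensor. With batch > 1 each row may pick different heavy hitters, so a per-row torch.gather with the index expanded over the head and dim axes is needed instead.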

Score accumulation

Per-token attention mass is accumulated as:

state.score[b, t] += attn_probs[b, :, :, t].sum()    (summed over all heads h and all query positions q at this step)

That is: for each token in the cache, add the attention probability that every query at this step paid to it, summed across all heads. The accumulation runs every step, so tokens that many heads attend to repeatedly build up large scores, while tokens attended to only once or twice stop growing, fall behind, and become eviction candidates.
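The accumulation rule above, written out over nested lists in a [batch][heads][q_len][k_len] layout. This is a sketch under two assumptions the patch description does not spell out: the score rows are created lazily, and tokens appended since the last step start at zero before their mass is added.

```python
from typing import List

# attn_probs layout: [batch][heads][q_len][k_len], as the doc assumes.
Probs = List[List[List[List[float]]]]


def accumulate(score: List[List[float]], attn_probs: Probs) -> List[List[float]]:
    """In place: score[b][t] += sum over heads h and queries q of attn_probs[b][h][q][t]."""
    while len(score) < len(attn_probs):      # lazy init: one row per batch element
        score.append([])
    for b, heads in enumerate(attn_probs):
        row = score[b]
        k_len = len(heads[0][0])
        row.extend([0.0] * (k_len - len(row)))  # tokens new this step start at zero
        for head in heads:
            for q_row in head:
                for t, p in enumerate(q_row):
                    row[t] += p
    return score


s = accumulate([], [[[[0.5, 0.25, 0.25]], [[0.0, 0.5, 0.5]]]])
print(s)  # [[0.5, 0.75, 0.75]]
```

With torch tensors the same update is attn_probs.sum(dim=(1, 2)), a [batch, k_len] tensor, added to the zero-padded score tensor.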

Cross-head, not per-head. This is a defensible default, but it has a downside: heads that specialize (e.g., one head dedicated to attention sinks, another to retrieval) get summed together, so a token that matters to one head can lose to tokens that many heads attend to weakly. Per-head policies can outperform the cross-head sum on some workloads; this implementation does not currently expose that knob. PRs welcome.
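A toy illustration of the cross-head downside, using hypothetical helpers (the patch exposes no per-head mode, and the numbers are made up). Head 0 puts almost all of its mass on token 3, while heads 1 and 2 share tokens 0 and 1. With room for two heavy hitters, the cross-head sum drops token 3 even though head 0 depends on it.

```python
from typing import Dict, List, Set


def _top(values: List[float], k: int) -> Set[int]:
    return set(sorted(range(len(values)), key=lambda t: values[t], reverse=True)[:k])


def cross_head_topk(per_head: List[List[float]], k: int) -> Set[int]:
    """Cross-head policy: sum each token's mass over heads, keep the top k."""
    totals = [sum(column) for column in zip(*per_head)]
    return _top(totals, k)


def per_head_topk(per_head: List[List[float]], k: int) -> Dict[int, Set[int]]:
    """Hypothetical per-head policy: every head keeps its own top k."""
    return {h: _top(scores, k) for h, scores in enumerate(per_head)}


# Rows are heads, columns are tokens 0..3 (illustrative numbers).
heads = [
    [0.1, 0.0, 0.0, 0.9],  # head 0: specialized on token 3
    [0.6, 0.4, 0.0, 0.0],
    [0.3, 0.7, 0.0, 0.0],
]
print(cross_head_topk(heads, 2))  # {0, 1}: token 3 is dropped
print(per_head_topk(heads, 1))    # {0: {3}, 1: {0}, 2: {1}}
```

The per-head variant keeps token 3 for head 0, at the cost of a per-head gather, since each head then holds a different set of positions.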

When eviction triggers

evict_every=N controls how often _maybe_evict actually runs. Default is 1 (every step). Larger values reduce per-step overhead but increase peak cache size between evictions:

peak_cache_size <= n_sink + budget + n_recent + (evict_every - 1)

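This bound counts the cache between decode steps and assumes one new token per step; the forward that triggers an eviction briefly attends over one extra entry, and during prefill the cache holds the full prompt until an eviction check trims it. For most workloads evict_every=1 is fine. If the eviction step shows up in profiling as a measurable cost, increase evict_every to 8 or 16. The top-k selection itself is cheap (argpartition on a few thousand floats is fast); the larger cost is usually the gather that copies the kept K/V entries in every layer.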
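A small discrete simulation of the peak-size bound, assuming one new token per decode step and that _maybe_evict runs after the original forward has appended that token, as described in the patch surface section. simulate is a hypothetical helper, not part of the patch.

```python
from typing import Tuple


def simulate(steps: int, n_sink: int, budget: int, n_recent: int,
             evict_every: int) -> Tuple[int, int]:
    """Return (peak size between steps, peak size seen inside a forward)."""
    cap = n_sink + budget + n_recent
    seq = cap                      # start right after an eviction
    peak_at_rest = peak_in_forward = seq
    for step in range(1, steps + 1):
        seq += 1                   # the original forward appends one token
        peak_in_forward = max(peak_in_forward, seq)
        if step % evict_every == 0 and seq > cap:
            seq = cap              # _maybe_evict trims back to capacity
        peak_at_rest = max(peak_at_rest, seq)
    return peak_at_rest, peak_in_forward


print(simulate(100, n_sink=4, budget=64, n_recent=32, evict_every=8))  # (107, 108)
```

The between-steps peak matches n_sink + budget + n_recent + (evict_every - 1); the forward that triggers an eviction sees one entry more.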

Edge cases

A few cases the code handles explicitly:

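  • Cache below threshold. If the layer's cached sequence length seq is <= n_sink + budget + n_recent, _maybe_evict is a no-op (note that len() on a DynamicCache returns the number of layers, not seq). No eviction happens until the cache grows past that threshold.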
  • First call. The accumulated score tensor is initialized lazily on the first forward pass once we know batch size and current cache length.
  • Cache reset between generations. reset_eviction_scores(model) zeroes all accumulated scores. Call this between independent generations; otherwise the previous generation's heavy hitters bias the next one.
  • Model on multiple devices. The score tensor lives on the same device as the cache. With device_map="auto" model sharding, each layer's score lives on its layer's device; nothing special to do.
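The lazy-initialization and reset behavior from the list above, sketched as a plain dataclass. This EvictionState is hypothetical and uses nested lists in place of a tensor; with torch you would allocate the score with device=key.device so it lives on the layer's device. One deliberate choice of the sketch: reset() drops the scores instead of zeroing them in place, so the next forward re-initializes them for whatever batch size and cache length it sees. reset_eviction_scores is described as zeroing, which may behave differently.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EvictionState:
    budget: int
    n_sink: int
    n_recent: int
    evict_every: int = 1
    score: Optional[List[List[float]]] = None  # [batch][token], created lazily

    def ensure(self, batch: int, seq: int) -> None:
        """Create the score rows on first use, then pad them to the cache length."""
        if self.score is None:
            self.score = [[0.0] * seq for _ in range(batch)]
        for row in self.score:
            row.extend([0.0] * (seq - len(row)))  # newly cached tokens start at zero

    def reset(self) -> None:
        """Forget all scores; the next forward re-initializes them lazily."""
        self.score = None


state = EvictionState(budget=8, n_sink=4, n_recent=16)
state.ensure(batch=2, seq=5)
print(state.score)  # [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
```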

Failure modes to watch for

If eviction breaks, you'll typically see one of these:

  • Repetitive garbage output -> sinks were evicted. Verify n_sink >= 4 and that index 0 of key_cache is preserved across eviction calls.
  • Topical drift / stale context -> recent window too small. Increase n_recent.
  • Sudden quality cliff at long context -> heavy-hitter score accumulation has saturated or normalized incorrectly. Check that state.score is being updated every step (not just on eviction-trigger steps).
  • Memory growing despite eviction -> the cache class is not DynamicCache and the patch's slicing assumptions don't hold. Print type(past_key_value) from inside the patched forward to verify.
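A hypothetical diagnostic for the first two failure modes, plus a score/cache misalignment check. It assumes you maintain a positions list next to each layer's cache (the absolute position of the token in each slot, trimmed with the same keep indices on every eviction); that bookkeeping is not part of the patch as described.

```python
from typing import List


def check_invariants(positions: List[int], total_seen: int, n_sink: int,
                     n_recent: int, score_len: int) -> List[str]:
    """Return the problems found in one layer's cache (an empty list means OK).

    positions[i] is the absolute token position held in cache slot i; it is
    assumed to be trimmed with the same keep indices as key_cache/value_cache.
    """
    problems = []
    sinks = min(n_sink, total_seen)
    if positions[:sinks] != list(range(sinks)):
        problems.append("sinks evicted or reordered")
    recent = min(n_recent, total_seen)
    if recent and positions[-recent:] != list(range(total_seen - recent, total_seen)):
        problems.append("recent window not intact")
    if score_len != len(positions):
        problems.append("score length out of sync with cache")
    return problems


print(check_invariants([0, 1, 2, 3, 10, 57, 98, 99], total_seen=100,
                       n_sink=4, n_recent=2, score_len=8))  # []
```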

Porting to a different attention class

To adapt this patch for non-MLA attention (Llama, Qwen, Mistral, Gemma standard MHA / GQA):

  1. Find the attention class in the relevant transformers/models/<arch>/modeling_<arch>.py.
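  2. Identify how it stores K/V in the cache. Most use DynamicCache with [batch, heads, seq, head_dim]; under GQA, heads here is the number of KV heads, which is smaller than the number of query heads in the attention probabilities. Because the score sums over heads, that mismatch does not change which tokens are kept. MLA stores expanded K/V at [batch, heads, seq, qk_dim] and [batch, heads, seq, v_dim] separately because qk_dim != v_dim. For uniform-dim attention, the slicing in _maybe_evict simplifies.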
  3. Replace the DeepseekV3Attention class lookup in install_kv_eviction with the relevant class.
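  4. Verify the attention probability tensor shape returned by the original forward; the score-accumulation code assumes [batch, heads, q_len, k_len]. Fused attention paths (SDPA, FlashAttention) typically return no probabilities at all, so you may need to load the model with attn_implementation="eager".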

The H2O recipe itself does not need to change; only the cache management code does.
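A sketch of why only the cache-management code changes when porting. Gathering along seq never touches the head count or the last dimension, so one keep list per batch row applies unchanged to MLA's K (qk_dim) and V (v_dim) and to a GQA cache with fewer KV heads. gather_seq is a hypothetical helper over nested lists.

```python
from typing import List

Cache = List[List[List[List[float]]]]  # [batch][heads][seq][dim]


def gather_seq(cache: Cache, keep: List[List[int]]) -> Cache:
    """Keep the given seq indices (one list per batch row) in every head."""
    return [[[head[t] for t in keep[b]] for head in row] for b, row in enumerate(cache)]


# K and V with different last dims (3 vs 2), 2 heads, 4 cached tokens.
K = [[[[float(10 * h + t)] * 3 for t in range(4)] for h in range(2)]]
V = [[[[float(10 * h + t)] * 2 for t in range(4)] for h in range(2)]]
keep = [[0, 2, 3]]
print(gather_seq(K, keep)[0][1])  # [[10.0, 10.0, 10.0], [12.0, 12.0, 12.0], [13.0, 13.0, 13.0]]
print(gather_seq(V, keep)[0][0])  # [[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]
```

The torch equivalent, for an int64 index idx of shape [batch, kept], is torch.gather(k, 2, idx[:, None, :, None].expand(-1, k.size(1), -1, k.size(3))), applied identically to the K and V tensors.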

Testing methodology

The smoke test in src/kv_eviction_mla.py exercises the eviction-decision logic on a _FakeCache that mirrors DynamicCache's structure but skips the model. Real-model validation requires a GPU and a checkpoint. We recommend:

  1. Run the smoke test (no GPU required) to verify the patch logic.
  2. Patch a small DeepSeek/Kimi model (e.g., a 1B or 7B variant) and run a perplexity-sweep notebook on a public benchmark prompt set.
  3. Run RULER 128K NIAH at your target budget. Compare to full-cache baseline.
  4. Stress-test on your actual workload distribution.
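A stdlib-only smoke test in the spirit of step 1, far simpler than the project's _FakeCache: one layer, batch 1, one head, random attention, and each cache entry storing its own absolute position so the assertions can see what survived. TinyFakeCache and smoke_test are hypothetical.

```python
import random
from typing import List


class TinyFakeCache:
    """Hypothetical one-layer cache: batch 1, one head, 1-dim entries.

    Each entry stores its own absolute position so the test can see what survived.
    """

    def __init__(self) -> None:
        self.key_cache: List[List[List[List[float]]]] = [[[]]]  # [batch][heads][seq][dim]

    def append(self, pos: int) -> None:
        self.key_cache[0][0].append([float(pos)])

    def seq_len(self) -> int:
        return len(self.key_cache[0][0])


def smoke_test(steps: int = 200, n_sink: int = 4, budget: int = 8,
               n_recent: int = 16, seed: int = 0) -> List[int]:
    rng = random.Random(seed)
    cache, score, kept = TinyFakeCache(), [], []
    cap = n_sink + budget + n_recent
    for pos in range(steps):
        cache.append(pos)                            # "forward" appends a token
        probs = [rng.random() for _ in range(cache.seq_len())]
        total = sum(probs)
        score.append(0.0)                            # new token starts at zero
        score = [s + p / total for s, p in zip(score, probs)]
        seq = cache.seq_len()
        if seq > cap:                                # evict_every=1
            middle = range(n_sink, seq - n_recent)
            heavy = sorted(middle, key=lambda t: score[t], reverse=True)[:budget]
            keep = sorted([*range(n_sink), *heavy, *range(seq - n_recent, seq)])
            cache.key_cache[0][0] = [cache.key_cache[0][0][t] for t in keep]
            score = [score[t] for t in keep]         # keep score aligned with cache
        kept = [int(e[0]) for e in cache.key_cache[0][0]]
        assert len(kept) <= cap and len(score) == len(kept)
        assert kept[:n_sink] == list(range(min(n_sink, pos + 1)))
        if pos + 1 >= n_recent:
            assert kept[-n_recent:] == list(range(pos + 1 - n_recent, pos + 1))
    return kept


print(len(smoke_test()))  # 28
```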

A canonical RULER benchmark with this code is in the project roadmap.