Instructions to use GenomaLabs-com/kv-cache-eviction-mla with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use GenomaLabs-com/kv-cache-eviction-mla with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("GenomaLabs-com/kv-cache-eviction-mla", dtype="auto") - Notebooks
- Google Colab
- Kaggle
File size: 8,096 Bytes
"""B1: comprehensive multi-step validation of the H2O eviction logic on a mock
KV cache that mirrors the canonical DeepseekV3 cache structure.
This script extends the single-shot smoke test in src/kv_eviction_mla.py to
thousands of simulated generation steps, captures per-step cache size, and
emits a CSV showing the eviction mechanism stabilizes the cache at the
expected bound.
Why a mock cache instead of a full model:
The patch in src/kv_eviction_mla.py is currently aligned with the
transformers 4.x KV cache API (DynamicCache.key_cache / value_cache lists).
transformers 5.x reorganized the cache into DynamicCache.layers[i], so the
patch needs porting before it runs end-to-end on transformers 5.x. The
eviction *logic* is unchanged across transformers versions; the API plumbing
is what differs. This script validates the logic; the plumbing port is on
the roadmap.
Run:
python scripts/validate_eviction_random_init.py \
--steps 1000 --budget 64 --n-sink 4 --n-recent 16 \
--out-csv results/validate_eviction_random_init.csv
Expected output:
- For steps 1..(n_sink + budget + n_recent): cache grows linearly.
- For later steps: cache size stays at exactly (n_sink + budget + n_recent).
- Eviction events are logged each time the cache crosses the threshold.
"""
from __future__ import annotations
import argparse
import csv
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
import torch
# Import the eviction state class and the eviction function from the module
# under test.
from kv_eviction_mla import _EvictionState, _maybe_evict
class MockMLAcache:
    """Lightweight substitute for a per-layer key/value cache.

    It offers the two list attributes the eviction code indexes,
    ``key_cache[i]`` and ``value_cache[i]``, each a 4-D tensor laid out as
    (batch, heads, seq, dim) that can be sliced along the seq axis (dim 2).

    Keys and values are given separate head dimensions (``qk_dim`` and
    ``v_dim``) because, for DeepseekV3 / MLA, the two differ.
    """

    def __init__(self, num_layers: int, batch: int, heads: int, qk_dim: int, v_dim: int, device: str = "cpu"):
        """Create ``num_layers`` empty (zero-length seq axis) K and V tensors."""
        self._batch = batch
        self._heads = heads
        self._qk = qk_dim
        self._v = v_dim
        self._device = device

        self.key_cache = []
        self.value_cache = []
        for _ in range(num_layers):
            self.key_cache.append(torch.zeros(batch, heads, 0, qk_dim, device=device))
            self.value_cache.append(torch.zeros(batch, heads, 0, v_dim, device=device))

    def append_token(self, layer_idx: int) -> None:
        """Grow layer ``layer_idx`` by one randomly-filled token position."""
        leading = (self._batch, self._heads, 1)
        # Draw the key before the value so the RNG stream is consumed in a
        # fixed order.
        fresh_key = torch.randn(*leading, self._qk, device=self._device)
        fresh_value = torch.randn(*leading, self._v, device=self._device)

        grown_keys = torch.cat((self.key_cache[layer_idx], fresh_key), dim=2)
        grown_values = torch.cat((self.value_cache[layer_idx], fresh_value), dim=2)
        self.key_cache[layer_idx] = grown_keys
        self.value_cache[layer_idx] = grown_values

    def seq_len(self, layer_idx: int) -> int:
        """Return how many token positions layer ``layer_idx`` currently holds."""
        layer_keys = self.key_cache[layer_idx]
        return int(layer_keys.size(2))
def run_validation(
    steps: int = 1000,
    budget: int = 64,
    n_sink: int = 4,
    n_recent: int = 16,
    num_layers: int = 4,
    out_csv: Path = Path("results/validate_eviction_random_init.csv"),
) -> None:
    """Drive the eviction logic over many simulated decode steps and record cache sizes.

    Every step appends one token to each layer of a mock cache, assigns that
    layer a synthetic importance score, and calls ``_maybe_evict``. One row per
    step (max/avg cache size, amount over the cap, running eviction count) is
    collected, checked, and written to ``out_csv``.

    Args:
        steps: Number of simulated generation steps; must be at least 1.
        budget: Heavy-hitter budget passed to each ``_EvictionState``.
        n_sink: Sink-token count passed to each ``_EvictionState``.
        n_recent: Recent-token count passed to each ``_EvictionState``.
        num_layers: Number of mock cache layers; must be at least 1.
        out_csv: Destination CSV path; parent directories are created.

    Raises:
        ValueError: If ``steps`` or ``num_layers`` is smaller than 1.
        AssertionError: If any step's cache size exceeds the expected cap by
            more than one token, or if no eviction was observed even though
            the run went well past the cap.
    """
    # Fail early with a clear message: an empty run would otherwise die later
    # on rows[-1], and zero layers on max() of an empty list.
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")
    out_csv = Path(out_csv)

    # Mock-cache dimensions, named once so the banner, the cache, and the
    # memory estimate below cannot drift apart.
    mock_heads = 4
    qk_dim = 192
    v_dim = 128

    print(f"[B1] mock cache: {num_layers} layers, canonical DeepseekV3 dims (qk={qk_dim}, v={v_dim})")
    cache = MockMLAcache(num_layers=num_layers, batch=1, heads=mock_heads, qk_dim=qk_dim, v_dim=v_dim)

    # One eviction state per layer, as install_kv_eviction would create
    states = [
        _EvictionState(budget=budget, n_sink=n_sink, n_recent=n_recent, evict_every=1)
        for _ in range(num_layers)
    ]

    expected_cap = n_sink + budget + n_recent
    print(f"[B1] eviction config: budget={budget} n_sink={n_sink} n_recent={n_recent}")
    print(f"[B1] expected cache cap per layer: {expected_cap} tokens")
    print(f"[B1] running {steps} simulated generation steps...")
    print()

    out_csv.parent.mkdir(parents=True, exist_ok=True)

    fieldnames = [
        "step",
        "max_cache_size",
        "avg_cache_size",
        "expected_cap",
        "over_cap",
        "eviction_events_total",
    ]
    progress_every = max(steps // 20, 1)
    rows = []
    eviction_events = 0
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t_start = time.perf_counter()

    for step in range(steps):
        # Simulate one generation step: each layer gets one new token appended.
        for layer_idx, state in enumerate(states):
            cache.append_token(layer_idx)

            # In the real patch, scores come from attn_weights at every
            # forward call. We synthesize them here with a distribution that
            # has a few clear heavy hitters so eviction has a non-trivial
            # decision to make.
            kv_len = cache.seq_len(layer_idx)
            new_mass = torch.rand(1, kv_len) * 0.1  # baseline noise
            # Plant a few "heavy hitters" with much larger mass
            heavy_idx = torch.randperm(kv_len)[: max(1, kv_len // 10)]
            new_mass[0, heavy_idx] += torch.rand(len(heavy_idx)) * 0.9 + 0.5

            # NOTE(review): kv_len grew by one just above, so unless
            # _maybe_evict leaves the score at exactly the new length, this
            # mismatch branch fires every step and the score is replaced
            # rather than accumulated - confirm against _maybe_evict.
            if state.score is None or state.score.shape[-1] != kv_len:
                state.score = new_mass
            else:
                state.score = state.score + new_mass

            # Trigger eviction logic (this is the function the patched forward calls)
            size_before = cache.seq_len(layer_idx)
            _maybe_evict(cache, layer_idx, state)
            size_after = cache.seq_len(layer_idx)
            if size_after < size_before:
                eviction_events += 1

        sizes = [cache.seq_len(i) for i in range(num_layers)]
        max_size = max(sizes)
        avg_size = sum(sizes) / len(sizes)
        rows.append({
            "step": step,
            "max_cache_size": max_size,
            "avg_cache_size": round(avg_size, 1),
            "expected_cap": expected_cap,
            "over_cap": max(0, max_size - expected_cap),
            "eviction_events_total": eviction_events,
        })

        if step % progress_every == 0:
            print(f" [step {step:>5d}/{steps}] max={max_size:>5d} avg={avg_size:.1f} events={eviction_events}")

    elapsed = time.perf_counter() - t_start
    # Guard the rate: a very short run can measure zero elapsed time.
    steps_per_sec = len(rows) / elapsed if elapsed > 0 else float("inf")
    print()
    print(f"[B1] done: {len(rows)} steps in {elapsed:.1f}s ({steps_per_sec:.1f} steps/sec)")

    # Sanity checks. Raised explicitly (not via `assert`) so they still run
    # under `python -O`.
    final = rows[-1]
    print(f"\n[B1] final state: max_cache={final['max_cache_size']} expected_cap={expected_cap} over_cap={final['over_cap']}")
    # The PASS line below claims the cap held throughout, so verify the peak
    # over every recorded step rather than only the last one.
    worst = max(rows, key=lambda row: row["max_cache_size"])
    if worst["max_cache_size"] > expected_cap + 1:
        raise AssertionError(f"cache exceeds cap: {worst}")
    print("[B1] PASS: cache stayed at or below expected cap throughout")

    if eviction_events == 0 and steps > expected_cap + 10:
        raise AssertionError("no eviction events observed despite running past cap")
    print(f"[B1] PASS: {eviction_events} eviction events triggered correctly")

    # Write CSV
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print(f"\n[B1] wrote {len(rows)} rows -> {out_csv}")

    # Back-of-envelope size of the evicted cache at the layer/head counts
    # printed in the banner, with 2 bytes per element for FP16.
    canonical_layers = 61
    canonical_heads = 64
    bytes_per_elem = 2
    evicted_gb = (
        canonical_layers * expected_cap * canonical_heads * (qk_dim + v_dim) * bytes_per_elem / 1e9
    )
    print(f"[B1] memory model: at budget={budget} on canonical DeepseekV3 ({canonical_layers}L, {canonical_heads}H, FP16):")
    print(" full cache @ 32K ctx: ~82 GB")
    print(f" evicted cache @ {budget}: ~{evicted_gb:.1f} GB")
def main() -> None:
    """Read the command-line options and pass them on to ``run_validation``."""
    parser = argparse.ArgumentParser()

    # Integer-valued flags and their defaults, registered in display order.
    integer_flags = (
        ("--steps", 1000),
        ("--budget", 64),
        ("--n-sink", 4),
        ("--n-recent", 16),
        ("--num-layers", 4),
    )
    for flag, fallback in integer_flags:
        parser.add_argument(flag, type=int, default=fallback)
    parser.add_argument(
        "--out-csv",
        type=Path,
        default=Path("results/validate_eviction_random_init.csv"),
    )

    options = parser.parse_args()
    # argparse turns dashes into underscores, so the namespace attribute names
    # line up one-to-one with run_validation's keyword parameters.
    run_validation(**vars(options))


# Only run when executed as a script, not on import.
if __name__ == "__main__":
    main()
|