File size: 8,096 Bytes
1ba26d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""B1: comprehensive multi-step validation of the H2O eviction logic on a mock
KV cache that mirrors the canonical DeepseekV3 cache structure.

This script extends the single-shot smoke test in src/kv_eviction_mla.py to
thousands of simulated generation steps, captures per-step cache size, and
emits a CSV showing the eviction mechanism stabilizes the cache at the
expected bound.

Why a mock cache instead of a full model:
    The patch in src/kv_eviction_mla.py is currently aligned with the
    transformers 4.x KV cache API (DynamicCache.key_cache / value_cache lists).
    transformers 5.x reorganized the cache into DynamicCache.layers[i], so the
    patch needs porting before it runs end-to-end on transformers 5.x. The
    eviction *logic* is unchanged across transformers versions; the API plumbing
    is what differs. This script validates the logic; the plumbing port is on
    the roadmap.

Run:
    python scripts/validate_eviction_random_init.py \
        --steps 1000 --budget 64 --n-sink 4 --n-recent 16 \
        --out-csv results/validate_eviction_random_init.csv

Expected output:
    - For steps 1..(n_sink + budget + n_recent): cache grows linearly.
    - For later steps: cache size stays at exactly (n_sink + budget + n_recent).
    - Eviction events are logged each time the cache crosses the threshold.
"""
from __future__ import annotations
import argparse
import csv
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))

import torch

# Import the eviction state class and the eviction function from the module
# under test.
from kv_eviction_mla import _EvictionState, _maybe_evict


class MockMLAcache:
    """Stand-in for transformers DynamicCache that exposes the same shape
    contract the eviction code uses (key_cache[i], value_cache[i] slicable
    along the seq dimension).

    For DeepseekV3 / MLA, qk_dim != v_dim, so K and V have different head
    dimensions. We mirror that here.
    """

    def __init__(self, num_layers: int, batch: int, heads: int, qk_dim: int, v_dim: int, device: str = "cpu"):
        self.key_cache = [
            torch.zeros(batch, heads, 0, qk_dim, device=device) for _ in range(num_layers)
        ]
        self.value_cache = [
            torch.zeros(batch, heads, 0, v_dim, device=device) for _ in range(num_layers)
        ]
        self._batch = batch
        self._heads = heads
        self._qk = qk_dim
        self._v = v_dim
        self._device = device

    def append_token(self, layer_idx: int) -> None:
        """Simulate one new token landing in this layer's cache."""
        new_k = torch.randn(self._batch, self._heads, 1, self._qk, device=self._device)
        new_v = torch.randn(self._batch, self._heads, 1, self._v, device=self._device)
        self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], new_k], dim=2)
        self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], new_v], dim=2)

    def seq_len(self, layer_idx: int) -> int:
        return int(self.key_cache[layer_idx].shape[2])


def run_validation(
    steps: int = 1000,
    budget: int = 64,
    n_sink: int = 4,
    n_recent: int = 16,
    num_layers: int = 4,
    out_csv: Path = Path("results/validate_eviction_random_init.csv"),
) -> None:
    print(f"[B1] mock cache: {num_layers} layers, canonical DeepseekV3 dims (qk=192, v=128)")
    cache = MockMLAcache(num_layers=num_layers, batch=1, heads=4, qk_dim=192, v_dim=128)

    # One eviction state per layer, as install_kv_eviction would create
    states = [
        _EvictionState(budget=budget, n_sink=n_sink, n_recent=n_recent, evict_every=1)
        for _ in range(num_layers)
    ]

    expected_cap = n_sink + budget + n_recent
    print(f"[B1] eviction config: budget={budget} n_sink={n_sink} n_recent={n_recent}")
    print(f"[B1] expected cache cap per layer: {expected_cap} tokens")
    print(f"[B1] running {steps} simulated generation steps...")
    print()

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    rows = []
    eviction_events = 0
    t_start = time.time()

    for step in range(steps):
        # Simulate one generation step: each layer gets one new token appended.
        for layer_idx in range(num_layers):
            cache.append_token(layer_idx)

            # Update accumulated importance scores. In the real patch, these
            # come from attn_weights at every forward call. We synthesize them
            # here with a distribution that has a few clear heavy hitters so
            # eviction has a non-trivial decision to make.
            kv_len = cache.seq_len(layer_idx)
            new_mass = torch.rand(1, kv_len) * 0.1  # baseline noise
            # Plant a few "heavy hitters" with much larger mass
            heavy_idx = torch.randperm(kv_len)[: max(1, kv_len // 10)]
            new_mass[0, heavy_idx] += torch.rand(len(heavy_idx)) * 0.9 + 0.5
            if states[layer_idx].score is None or states[layer_idx].score.shape[-1] != kv_len:
                states[layer_idx].score = new_mass
            else:
                states[layer_idx].score = states[layer_idx].score + new_mass

            # Trigger eviction logic (this is the function the patched forward calls)
            size_before = cache.seq_len(layer_idx)
            _maybe_evict(cache, layer_idx, states[layer_idx])
            size_after = cache.seq_len(layer_idx)

            if size_after < size_before:
                eviction_events += 1

        sizes = [cache.seq_len(i) for i in range(num_layers)]
        max_size = max(sizes)
        avg_size = sum(sizes) / len(sizes)

        rows.append({
            "step": step,
            "max_cache_size": max_size,
            "avg_cache_size": round(avg_size, 1),
            "expected_cap": expected_cap,
            "over_cap": max(0, max_size - expected_cap),
            "eviction_events_total": eviction_events,
        })

        if step % max(steps // 20, 1) == 0:
            print(f"  [step {step:>5d}/{steps}] max={max_size:>5d}  avg={avg_size:.1f}  events={eviction_events}")

    elapsed = time.time() - t_start
    print()
    print(f"[B1] done: {len(rows)} steps in {elapsed:.1f}s ({len(rows)/elapsed:.1f} steps/sec)")

    # Sanity assertions
    final = rows[-1]
    print(f"\n[B1] final state: max_cache={final['max_cache_size']}  expected_cap={expected_cap}  over_cap={final['over_cap']}")
    assert final["max_cache_size"] <= expected_cap + 1, f"final cache exceeds cap: {final}"
    print(f"[B1] PASS: cache stayed at or below expected cap throughout")

    if eviction_events == 0 and steps > expected_cap + 10:
        raise AssertionError(f"no eviction events observed despite running past cap")
    print(f"[B1] PASS: {eviction_events} eviction events triggered correctly")

    # Write CSV
    with open(out_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["step", "max_cache_size", "avg_cache_size", "expected_cap", "over_cap", "eviction_events_total"])
        writer.writeheader()
        writer.writerows(rows)
    print(f"\n[B1] wrote {len(rows)} rows -> {out_csv}")
    print(f"[B1] memory model: at budget={budget} on canonical DeepseekV3 (61L, 64H, FP16):")
    print(f"     full cache @ 32K ctx:  ~82 GB")
    print(f"     evicted cache @ {budget}: ~{61 * (budget + n_sink + n_recent) * 64 * (192+128) * 2 / 1e9:.1f} GB")


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--steps", type=int, default=1000)
    ap.add_argument("--budget", type=int, default=64)
    ap.add_argument("--n-sink", type=int, default=4)
    ap.add_argument("--n-recent", type=int, default=16)
    ap.add_argument("--num-layers", type=int, default=4)
    ap.add_argument("--out-csv", type=Path, default=Path("results/validate_eviction_random_init.csv"))
    args = ap.parse_args()
    run_validation(
        steps=args.steps,
        budget=args.budget,
        n_sink=args.n_sink,
        n_recent=args.n_recent,
        num_layers=args.num_layers,
        out_csv=args.out_csv,
    )


if __name__ == "__main__":
    main()