File size: 4,387 Bytes
b786614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10e074b
 
 
b786614
 
10e074b
 
 
 
 
 
b786614
10e074b
b786614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
eviction.py — Token scoring and KV cache pruning.

Core O(n) eviction policy:
  1. Score each token position using offline-profiled prototype centroids.
  2. Keep the top-budget tokens (attention sink + recency anchors + semantic prototypes).
  3. Prune the KV cache to exactly `budget` positions.

This is coordinate-free and RoPE-compatible: we only select positions, never
reorder them, so relative position encodings remain valid.
"""

from __future__ import annotations
import torch
import numpy as np
from typing import Optional, Dict, Tuple, List

from .utils import to_tuple_kv, to_dynamic_cache


def score_tokens(
    prototypes: Optional[Dict],
    seq_len: int,
    budget: int,
) -> np.ndarray:
    """
    Score all token positions using prototype centroid histograms.

    Algorithm:
      - For each profiled prototype entry, take its first centroid and give
        position ``p`` the sum of the centroid's first
        ``min(len(centroid), seq_len - p)`` values (a prefix sum lookup).
      - Boost the attention sinks (the first 4 positions) by ``100 * peak``.
      - Boost a recency window at the tail (``max(8, budget // 2)`` positions,
        capped at ``seq_len``) by ``50 * peak``.
      - Add a small deterministic tiebreaker that grows with position index.

    ``peak`` is the largest prototype score, or 1.0 when no score is positive.

    Args:
        prototypes: Mapping whose values expose ``data["centroids"][0]`` as a
                    1-D sequence of weights (output of ``build_prototypes()``).
                    If None, only the sink/recency boosts and the tiebreaker
                    contribute.
        seq_len:    Current sequence length to score.
        budget:     Target number of tokens to keep; only used to size the
                    recency window.

    Returns:
        scores: (seq_len,) float64 array. Higher = more important. An empty
        array is returned when ``seq_len`` is 0.
    """
    scores = np.zeros(seq_len, dtype=np.float64)
    if seq_len == 0:
        # Nothing to score; also avoids calling max() on an empty array.
        return scores

    if prototypes is not None:
        # tail_len[p] = number of positions from p to the end of the sequence.
        tail_len = seq_len - np.arange(seq_len)
        for data in prototypes.values():
            centroid = data["centroids"][0]            # shape: (profile_seq_len,)
            max_d = min(len(centroid), seq_len)
            if max_d == 0:
                continue
            cumsum = np.cumsum(centroid[:max_d])
            # reach is always >= 1 here (tail_len >= 1 and max_d >= 1), so
            # reach - 1 is a valid index into the prefix sums.
            reach = np.minimum(tail_len, max_d)
            scores += cumsum[reach - 1]

    # ── Robust Split-Budget Boosting (Sinks + 50% Recency + 50% Semantic) ─────
    # The boosts are scaled by the peak prototype score so that sinks and the
    # recency window always outrank purely semantic positions.
    top = scores.max()
    peak = top if top > 0 else 1.0

    # 1. Attention sinks: the first 4 positions (fewer if the sequence is shorter).
    scores[:4] += peak * 100.0

    # 2. Recency window: the last `recency_window` positions.
    recency_window = min(max(8, budget // 2), seq_len)
    scores[seq_len - recency_window:] += peak * 50.0

    # ── Deterministic tiebreaker (prefer later tokens among equals) ───────────
    scores += np.linspace(0, 1e-4, seq_len)

    return scores


def select_indices(scores: np.ndarray, budget: int) -> List[int]:
    """
    Return the indices of the ``budget`` highest scores, in ascending order.

    Ascending order preserves the original sequence order of the kept tokens.
    Ties are broken in favour of later positions (stable sort, tail taken).

    Args:
        scores: 1-D array of per-position scores. Higher = more important.
        budget: Maximum number of indices to return. Values larger than
                ``len(scores)`` are capped; values <= 0 select nothing.

    Returns:
        Sorted list of at most ``budget`` Python ints.
    """
    actual_budget = min(budget, len(scores))
    if actual_budget <= 0:
        # Guard explicitly: a slice of [-0:] would select every index.
        return []
    # Stable sort makes the choice among equal scores deterministic.
    top = np.argsort(scores, kind="stable")[-actual_budget:]
    return sorted(top.tolist())


def prune_kv_cache(
    past_key_values,
    indices: List[int],
    device: torch.device,
):
    """
    Prune a KV cache to the given token indices.

    Every key and value tensor is reduced with ``index_select`` along
    dimension 2, keeping only the positions listed in ``indices``.

    Args:
        past_key_values: DynamicCache or legacy tuple from a model forward pass.
        indices:         Sorted list of token indices to keep.
        device:          CUDA/CPU device on which the index tensor is first
                         created. If a layer's tensors live on a different
                         device, the index is copied there once and reused.

    Returns:
        Pruned KV cache as produced by ``to_dynamic_cache`` (presumably the
        format the model expects — see ``.utils``).
    """
    base_idx = torch.tensor(indices, dtype=torch.long, device=device)
    # index_select needs the index on the same device as its input, so keep
    # one copy of the index per device encountered (relevant when the cache
    # is sharded across devices; a no-op for single-device caches).
    idx_by_device = {base_idx.device: base_idx}

    def _idx_on(target: torch.device) -> torch.Tensor:
        """Return the index tensor resident on ``target``, copying on first use."""
        if target not in idx_by_device:
            idx_by_device[target] = base_idx.to(target)
        return idx_by_device[target]

    kv_tuple = to_tuple_kv(past_key_values)
    pruned = tuple(
        (
            k.index_select(2, _idx_on(k.device)),
            v.index_select(2, _idx_on(v.device)),
        )
        for k, v in kv_tuple
    )
    return to_dynamic_cache(pruned)


def evict(
    past_key_values,
    budget: int,
    prototypes: Optional[Dict],
    seq_len: int,
    device: torch.device,
):
    """
    Run the full eviction pipeline in one call: score, select, prune.

    When the sequence already fits (``seq_len <= budget``) the cache is handed
    back untouched. Otherwise positions are scored with ``score_tokens``, the
    best ``budget`` of them are chosen with ``select_indices``, and the cache
    is reduced to those positions with ``prune_kv_cache``.
    """
    fits_in_budget = seq_len <= budget
    if fits_in_budget:
        # Nothing to evict.
        return past_key_values

    keep = select_indices(score_tokens(prototypes, seq_len, budget), budget)
    return prune_kv_cache(past_key_values, keep, device)