Spaces:
Sleeping
Sleeping
File size: 4,387 Bytes
b786614 10e074b b786614 10e074b b786614 10e074b b786614 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | """
eviction.py — Token scoring and KV cache pruning.
Core eviction policy (O(n) scoring per profiled head, O(n log n) selection via sort):
1. Score each token position using offline-profiled prototype centroids.
2. Keep the top-budget tokens (attention sink + recency anchors + semantic prototypes).
3. Prune the KV cache to exactly `budget` positions.
This is coordinate-free and RoPE-compatible: we only select positions, never
reorder them, so relative position encodings remain valid.
"""
from __future__ import annotations
import torch
import numpy as np
from typing import Optional, Dict, Tuple, List
from .utils import to_tuple_kv, to_dynamic_cache
def score_tokens(
    prototypes: Optional[Dict],
    seq_len: int,
    budget: int,
) -> np.ndarray:
    """
    Score every token position; higher scores are kept first by the caller.

    Algorithm:
      - For each entry of ``prototypes``, take the first centroid, build its
        cumulative sum over the first ``min(len(centroid), seq_len)`` entries,
        and add to position ``p`` the cumulative value at index
        ``min(len(centroid), seq_len - p) - 1``.
      - Boost the first 4 positions (attention sinks) by ``100 * peak``.
      - Boost the last ``min(max(8, budget // 2), seq_len)`` positions
        (recency window) by ``50 * peak``.
      - Add a tiny linear ramp (0 .. 1e-4) so that later positions win ties.

    ``peak`` is the largest prototype score, or 1.0 when no score is positive,
    so the boosts always dominate the prototype contribution.

    Args:
        prototypes: Mapping whose values are dicts with a ``"centroids"``
            entry; only ``data["centroids"][0]`` is read (presumably the
            output of ``build_prototypes()`` keyed by (layer, head) — confirm
            against the profiler). If None, only the boosts and the tiebreaker
            contribute.
        seq_len: Number of positions to score.
        budget: Target number of tokens to keep; only used to size the
            recency window.

    Returns:
        scores: (seq_len,) float64 array. Higher = more important. Empty when
        ``seq_len == 0``.
    """
    scores = np.zeros(seq_len, dtype=np.float64)
    if seq_len == 0:
        # Nothing to score; scores.max() below would raise on an empty array.
        return scores

    if prototypes is not None:
        # Number of positions from p to the end of the sequence, for every p.
        remaining = seq_len - np.arange(seq_len)
        for data in prototypes.values():
            centroid = data["centroids"][0]
            max_d = min(len(centroid), seq_len)
            if max_d == 0:
                continue
            cumsum = np.cumsum(centroid[:max_d])
            # reach >= 1 everywhere because remaining >= 1 and max_d >= 1.
            reach = np.minimum(remaining, max_d)
            scores += cumsum[reach - 1]

    # ── Split-budget boosting (sinks + recency window) ────────────────────────
    # Boosts are scaled by the prototype peak so they always outrank the
    # semantic scores, guaranteeing the sinks and a contiguous local window.
    top = scores.max()
    peak = top if top > 0 else 1.0

    # 1. Attention sinks: the first 4 positions.
    scores[:min(4, seq_len)] += peak * 100.0

    # 2. Recency window: half the budget, but at least 8 positions.
    recency_window = min(max(8, budget // 2), seq_len)
    # Explicit start index: a negative-slice form would misbehave for 0.
    scores[seq_len - recency_window:] += peak * 50.0

    # ── Deterministic tiebreaker (prefer later tokens among equals) ───────────
    scores += np.linspace(0, 1e-4, seq_len)
    return scores
def select_indices(scores: np.ndarray, budget: int) -> List[int]:
    """
    Return the indices of the ``budget`` highest scores, in ascending order.

    Ascending order preserves the original sequence order of the kept tokens.

    Args:
        scores: 1-D array of per-position scores. Higher = more important.
        budget: Maximum number of indices to return. Values larger than
            ``len(scores)`` are clamped; zero or negative values select
            nothing.

    Returns:
        Sorted list of at most ``budget`` integer indices.
    """
    keep = min(budget, len(scores))
    if keep <= 0:
        # Guard needed: ``order[-0:]`` is the whole array, not an empty one.
        return []
    # Stable sort keeps equal scores in index order, so taking the tail
    # deterministically prefers later positions among ties.
    order = np.argsort(scores, kind="stable")
    return sorted(order[-keep:].tolist())
def prune_kv_cache(
    past_key_values,
    indices: List[int],
    device: torch.device,
):
    """
    Prune a KV cache down to the given token positions.

    The cache is converted with ``to_tuple_kv``, every key and value tensor is
    gathered along dimension 2 with ``index_select``, and the pruned pairs are
    passed to ``to_dynamic_cache``.

    Args:
        past_key_values: Cache object accepted by ``to_tuple_kv`` (presumably
            a DynamicCache or legacy tuple from a model forward pass —
            confirm against ``utils``).
        indices: Token positions to keep, in the order they should appear in
            the pruned cache (callers pass them sorted ascending).
        device: Device on which the index tensor is first created. Tensors
            living on a different device get their own copy of the index.

    Returns:
        Whatever ``to_dynamic_cache`` builds from the pruned per-layer
        (key, value) pairs.
    """
    base_idx = torch.tensor(indices, dtype=torch.long, device=device)
    # index_select requires the index on the same device as its input, and a
    # sharded model may place layers on different devices. Keep one index
    # tensor per device so each transfer happens at most once.
    idx_by_device = {base_idx.device: base_idx}

    def _take(tensor):
        """Gather the kept positions along dim 2 (assumed sequence axis)."""
        idx = idx_by_device.get(tensor.device)
        if idx is None:
            idx = base_idx.to(tensor.device)
            idx_by_device[tensor.device] = idx
        return tensor.index_select(2, idx)

    kv_tuple = to_tuple_kv(past_key_values)
    pruned = tuple((_take(k), _take(v)) for k, v in kv_tuple)
    return to_dynamic_cache(pruned)
def evict(
    past_key_values,
    budget: int,
    prototypes: Optional[Dict],
    seq_len: int,
    device: torch.device,
):
    """
    Run the full eviction pipeline once: score, select, then prune.

    When the sequence already fits the budget (``seq_len <= budget``) the
    input cache object is handed back untouched. Otherwise the positions are
    scored with ``score_tokens``, the best ``budget`` of them are chosen with
    ``select_indices``, and the cache is pruned with ``prune_kv_cache``.
    """
    if seq_len > budget:
        # Over budget: rank positions, keep the winners, and prune the cache.
        keep = select_indices(score_tokens(prototypes, seq_len, budget), budget)
        return prune_kv_cache(past_key_values, keep, device)
    # Within budget: nothing to evict.
    return past_key_values
|