File size: 3,761 Bytes
b786614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e249c3f
 
 
 
 
 
 
f044920
b786614
f044920
b786614
 
 
 
f044920
b786614
 
 
 
 
 
 
 
 
f044920
b786614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
press.py — KVPress-compatible wrapper for ProactiveCache eviction.

Implements the ScorerPress scoring interface (a BasePress subclass) from
NVIDIA's kvpress library so that ProactiveCache can be benchmarked directly
against the 20+ methods in the KVPress standard evaluation suite.

Usage (requires: pip install kvpress):
    from proactive_cache import ProactiveCachePress
    press = ProactiveCachePress(compression_ratio=0.75, prototype_path="...")
    # Use with kvpress evaluation harness
"""

from __future__ import annotations
import os
import pickle
import torch
import numpy as np
from dataclasses import dataclass
from typing import Optional

from .eviction import score_tokens


# Python 3.13 removed the stdlib 'pipes' module, yet fire/kvpress still import it.
# When it is missing, register shlex under the name 'pipes' so that import succeeds
# (shlex provides `quote`, presumably the part of pipes those libraries rely on).
try:
    import pipes
except ImportError:
    import sys
    import shlex

    sys.modules['pipes'] = shlex

# ── KVPress ScorerPress compatibility shim ────────────────────────────────────
try:
    from kvpress import ScorerPress
    _KVPRESS_AVAILABLE = True
except ImportError:
    _KVPRESS_AVAILABLE = False

    class ScorerPress:
        """Minimal shim — allows import without kvpress installed."""
        def __init__(self):
            self.compression_ratio = 0.0

        def score(self, module, hidden_states, keys, values, attentions, kwargs):
            raise NotImplementedError("Install kvpress: pip install kvpress")


@dataclass
class ProactiveCachePress(ScorerPress):
    """
    KVPress-compatible Proactive KV Cache eviction plugin.

    Implements the ScorerPress ``score()`` hook, called once per attention layer
    during prefill. Returns an importance score per token position —
    higher score = keep, lower score = evict (following KVPress convention).

    Args:
        compression_ratio: Fraction of tokens to EVICT, in [0.0, 1.0).
            e.g. 0.75 → keep 25% of the KV cache (budget = seq_len * 0.25).
        prototype_path: Path to a prototypes .pkl file from ``ProactiveCache.profile()``.
            If None, or if the file does not exist, falls back to
            attention-sink + recency-only scoring.

    Raises:
        ValueError: if ``compression_ratio`` is outside [0.0, 1.0).

    Example:
        press = ProactiveCachePress(compression_ratio=0.75, prototype_path="protos.pkl")
    """
    compression_ratio: float = 0.5
    prototype_path: Optional[str] = None

    def __post_init__(self):
        """Validate the configuration and load prototypes from ``prototype_path`` if present."""
        # This override replaces any parent __post_init__, so validate the
        # documented range here. The `not (a <= x < b)` form also rejects NaN.
        if not 0.0 <= self.compression_ratio < 1.0:
            raise ValueError(
                f"compression_ratio must be in [0.0, 1.0), got {self.compression_ratio!r}"
            )
        # Chain to the parent's hook when it defines one (the local shim does not),
        # so base-class initialisation is not silently skipped.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()

        # None means score_tokens() receives no prototypes (sink+recency fallback).
        self._prototypes = None

        if not self.prototype_path:
            print("[ProactiveCachePress] No prototypes loaded — using sink+recency scoring.")
            return

        if not os.path.exists(self.prototype_path):
            # A path was given but is unusable; report it explicitly instead of
            # falling back silently.
            print(f"[ProactiveCachePress] prototype_path {self.prototype_path!r} not found — "
                  "using sink+recency scoring.")
            return

        # SECURITY: pickle.load can execute arbitrary code — only point
        # prototype_path at files you produced or otherwise trust.
        with open(self.prototype_path, "rb") as f:
            self._prototypes = pickle.load(f)
        print(f"[ProactiveCachePress] Loaded {len(self._prototypes)} prototypes "
              f"from {self.prototype_path}")

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        KVPress hook: called once per attention layer during the prefill pass.

        Only ``keys`` is read (for its shape and device). The other arguments
        are part of the hook signature and are ignored.

        Args:
            keys: 4-D tensor unpacked as (batch, heads, seq_len, head_dim).
                NOTE(review): under GQA, dim 1 is presumably the KV-head count — confirm.

        Returns:
            scores: (batch, heads, seq_len) float32 tensor on ``keys.device``.
                    Higher = more important. KVPress will keep the top-K tokens,
                    where K = seq_len * (1 - compression_ratio).
        """
        batch_size, num_heads, seq_len, _ = keys.shape
        # Number of tokens to keep; at least one so the cache is never emptied.
        budget = max(1, int(seq_len * (1.0 - self.compression_ratio)))

        # Position-only scores: they depend on prototypes/seq_len/budget, not on
        # keys or queries, so every batch row and head shares the same vector.
        proto_scores = score_tokens(self._prototypes, seq_len, budget)
        # as_tensor accepts a list, an ndarray or a tensor without the
        # copy-construct warning that torch.tensor(tensor) emits.
        proto_tensor = torch.as_tensor(proto_scores, dtype=torch.float32, device=keys.device)

        # Broadcast (1, 1, seq_len) → (batch, heads, seq_len). contiguous()
        # materialises the expanded view into real storage.
        scores = proto_tensor.unsqueeze(0).unsqueeze(0).expand(batch_size, num_heads, seq_len)
        return scores.contiguous()