File size: 9,759 Bytes
33efa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
"""
JaggedKVCache - Sparse KV Cache for Early Exit Inference.

This cache tracks per-layer sequence lengths, enabling efficient
generation with early exit heads that stop at different layers.
"""

import torch
from typing import List, Tuple, Optional


class JaggedKVCache:
    """
    Sparse KV Cache that tracks per-layer sequence lengths.

    Unlike standard KV caches where all layers have the same length,
    this cache allows different layers to have different cached lengths.
    This is essential for early exit inference where tokens may exit
    at different layers.

    Key features:
    - Per-layer KV storage with independent lengths
    - Lazy fill: missing positions are detected and can be computed on-demand
    - Truncation: efficient rollback on rejection
    - Cloning: snapshot for speculative drafting

    Attributes:
        num_layers: Total number of transformer layers
        batch_size: Batch size (typically 1 for inference)
        num_kv_heads: Number of key-value heads
        head_dim: Dimension of each head
        device: Device to store tensors on
        dtype: Data type for tensors
    """

    def __init__(
        self,
        num_layers: int,
        batch_size: int = 1,
        num_kv_heads: int = 8,
        head_dim: int = 128,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        # Per-layer storage: List of (key_cache, value_cache) or None.
        # Tensor shape is [batch, num_kv_heads, seq_len, head_dim];
        # seq_len (dim 2) may differ per layer — hence "jagged".
        self.layer_caches: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]

        # Track sequence length per layer (capacity = max_position + 1).
        # Invariant: equals the cached tensors' dim-2 size when non-empty.
        self.layer_seq_lengths: List[int] = [0] * num_layers

        # Track which positions are actually filled (for lazy fill detection).
        # Unfilled slots inside the allocated range hold zeros.
        self.filled_positions: List[set] = [set() for _ in range(num_layers)]

    def _new_zero_cache(self, seq_len: int) -> torch.Tensor:
        """Allocate a zero-initialized cache tensor with `seq_len` positions."""
        return torch.zeros(
            (self.batch_size, self.num_kv_heads, seq_len, self.head_dim),
            device=self.device,
            dtype=self.dtype,
        )

    def update(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update cache for a layer at specific positions.

        Handles three cases: first write of a contiguous prefix (fast path,
        plain clone), first write at arbitrary positions (allocate and
        scatter), and subsequent writes (grow if needed, then scatter).

        Args:
            layer_idx: Layer index to update
            key_states: [B, num_kv_heads, seq_len, head_dim] new key states
            value_states: [B, num_kv_heads, seq_len, head_dim] new value states
            cache_position: [seq_len] tensor of positions to update
                (need not be sorted; order matches key/value states)

        Returns:
            (full_keys, full_values) tuple with all cached data for the layer
        """
        positions = cache_position.tolist()
        # Capacity must cover the LARGEST position written. Do not assume
        # cache_position is sorted — taking the last element would
        # under-allocate (and raise on scatter) for unsorted inputs.
        new_len = max(positions) + 1
        input_seq_len = key_states.shape[2]
        index = cache_position.long()

        if self.layer_caches[layer_idx] is None:
            # First write for this layer. Fast path only when positions are
            # exactly [0, 1, ..., n-1] IN ORDER — a mere "starts at 0 and
            # spans new_len" check would accept permutations and store
            # states at the wrong offsets.
            if positions == list(range(input_seq_len)):
                self.layer_caches[layer_idx] = (
                    key_states.clone(),
                    value_states.clone(),
                )
            else:
                # Arbitrary positions: allocate full capacity and scatter.
                k_cache = self._new_zero_cache(new_len)
                v_cache = self._new_zero_cache(new_len)
                k_cache[:, :, index, :] = key_states
                v_cache[:, :, index, :] = value_states
                self.layer_caches[layer_idx] = (k_cache, v_cache)

            self.layer_seq_lengths[layer_idx] = new_len
        else:
            k_cache, v_cache = self.layer_caches[layer_idx]
            current_len = k_cache.shape[2]

            if new_len > current_len:
                # Grow the cache; the gap (if any) stays zero-filled and is
                # reported by get_unfilled_positions() for lazy fill.
                k_cache = torch.cat(
                    [k_cache, self._new_zero_cache(new_len - current_len)], dim=2
                )
                v_cache = torch.cat(
                    [v_cache, self._new_zero_cache(new_len - current_len)], dim=2
                )

            # Scatter the new states into place (overwrites are allowed).
            k_cache[:, :, index, :] = key_states
            v_cache[:, :, index, :] = value_states

            self.layer_caches[layer_idx] = (k_cache, v_cache)
            self.layer_seq_lengths[layer_idx] = max(
                self.layer_seq_lengths[layer_idx], new_len
            )

        # Record which positions now hold real (non-placeholder) data.
        self.filled_positions[layer_idx].update(positions)

        return self.layer_caches[layer_idx]

    def get_kv(self, layer_idx: int) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Get cached (keys, values) for a layer, or None if not cached."""
        return self.layer_caches[layer_idx]

    def get_seq_length(self, layer_idx: int) -> int:
        """Get the sequence length (allocated capacity) for a layer."""
        return self.layer_seq_lengths[layer_idx]

    def has_position(self, layer_idx: int, position: int) -> bool:
        """Check if a specific position is filled for a layer."""
        return position in self.filled_positions[layer_idx]

    def get_unfilled_positions(self, layer_idx: int, up_to: int) -> List[int]:
        """Get sorted positions not yet filled for a layer, up to `up_to` (exclusive)."""
        filled = self.filled_positions[layer_idx]
        return sorted(set(range(up_to)) - filled)

    def needs_fill(self, layer_idx: int, positions: List[int]) -> bool:
        """Check if any of the given positions need to be filled for a layer."""
        return not all(p in self.filled_positions[layer_idx] for p in positions)

    def get_missing_layers(self, position: int, target_layer: int) -> List[int]:
        """
        Get list of layers that need computation for a position.

        Args:
            position: The position we need KV for
            target_layer: The deepest layer we need to reach (inclusive)

        Returns:
            Layer indices in [0, target_layer] missing KV for this position
        """
        return [
            layer_idx
            for layer_idx in range(target_layer + 1)
            if position not in self.filled_positions[layer_idx]
        ]

    def truncate_from(self, position: int):
        """
        Truncate all layer caches from position onwards.
        Used for rollback on rejection.

        Args:
            position: First position to remove (keeps 0..position-1)
        """
        for layer_idx in range(self.num_layers):
            if self.layer_caches[layer_idx] is not None:
                k, v = self.layer_caches[layer_idx]
                if k.shape[2] > position:
                    # contiguous() so the kept prefix owns its storage and
                    # the truncated tail can be freed.
                    self.layer_caches[layer_idx] = (
                        k[:, :, :position, :].contiguous(),
                        v[:, :, :position, :].contiguous(),
                    )
                    self.layer_seq_lengths[layer_idx] = min(
                        self.layer_seq_lengths[layer_idx], position
                    )

            # Drop filled-position records at or beyond the cut point.
            self.filled_positions[layer_idx] = {
                p for p in self.filled_positions[layer_idx] if p < position
            }

    def clone(self) -> "JaggedKVCache":
        """
        Create a deep copy of the cache for speculative drafting.

        Returns:
            Independent copy that can be modified without affecting original
        """
        new_cache = JaggedKVCache(
            num_layers=self.num_layers,
            batch_size=self.batch_size,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            device=self.device,
            dtype=self.dtype,
        )
        for i, kv in enumerate(self.layer_caches):
            if kv is not None:
                # clone() copies tensor storage, so edits don't alias.
                new_cache.layer_caches[i] = (kv[0].clone(), kv[1].clone())
        new_cache.layer_seq_lengths = self.layer_seq_lengths.copy()
        new_cache.filled_positions = [s.copy() for s in self.filled_positions]
        return new_cache

    def reset(self):
        """Reset the cache to empty state."""
        self.layer_caches = [None for _ in range(self.num_layers)]
        self.layer_seq_lengths = [0] * self.num_layers
        self.filled_positions = [set() for _ in range(self.num_layers)]

    def __repr__(self) -> str:
        lines = [f"JaggedKVCache(num_layers={self.num_layers}, device={self.device})"]
        for i in range(min(self.num_layers, 10)):  # Show first 10 layers
            seq_len = self.layer_seq_lengths[i]
            filled = len(self.filled_positions[i])
            if seq_len > 0:
                lines.append(f"  Layer {i:2d}: {filled}/{seq_len} filled")
        if self.num_layers > 10:
            lines.append(f"  ... ({self.num_layers - 10} more layers)")
        return "\n".join(lines)