"""
JaggedKVCache - Sparse KV Cache for Early Exit Inference.

This cache tracks per-layer sequence lengths, enabling efficient
generation with early exit heads that stop at different layers.
"""

import torch
from typing import List, Tuple, Optional


class JaggedKVCache:
    """
    Sparse KV Cache that tracks per-layer sequence lengths.

    Unlike standard KV caches, where all layers have the same length,
    this cache allows different layers to have different cached lengths.
    This is essential for early exit inference, where tokens may exit
    at different layers.

    Key features:
    - Per-layer KV storage with independent lengths
    - Lazy fill: missing positions are detected and can be computed on demand
    - Truncation: efficient rollback on rejection
    - Cloning: snapshots for speculative drafting

    Attributes:
        num_layers: Total number of transformer layers
        batch_size: Batch size (typically 1 for inference)
        num_kv_heads: Number of key-value heads
        head_dim: Dimension of each head
        device: Device to store tensors on
        dtype: Data type for tensors
    """

    def __init__(
        self,
        num_layers: int,
        batch_size: int = 1,
        num_kv_heads: int = 8,
        head_dim: int = 128,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        # Per-layer (key, value) tensor pairs; None until a layer is first updated.
        self.layer_caches: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]

        # Highest cached position + 1 for each layer (the jagged lengths).
        self.layer_seq_lengths: List[int] = [0] * num_layers

        # Positions actually written per layer; the cache may contain
        # zero-filled holes at positions that were never written.
        self.filled_positions: List[set] = [set() for _ in range(num_layers)]

    def update(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update cache for a layer at specific positions.

        Args:
            layer_idx: Layer index to update
            key_states: [B, num_kv_heads, seq_len, head_dim] new key states
            value_states: [B, num_kv_heads, seq_len, head_dim] new value states
            cache_position: [seq_len] tensor of positions to update
                (assumed sorted in ascending order)

        Returns:
            (full_keys, full_values) tuple with all cached data; unfilled
            positions remain zero until lazily filled
        """
        new_len = cache_position[-1].item() + 1
        input_seq_len = key_states.shape[2]
        positions = cache_position.tolist()

        if self.layer_caches[layer_idx] is None:
            # Contiguous prefill from position 0: store the states directly.
            if cache_position[0].item() == 0 and input_seq_len == new_len:
                self.layer_caches[layer_idx] = (
                    key_states.clone(),
                    value_states.clone(),
                )
            else:
                # Non-contiguous first write: allocate zeros and scatter into place.
                k_cache = torch.zeros(
                    (self.batch_size, self.num_kv_heads, new_len, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                v_cache = torch.zeros(
                    (self.batch_size, self.num_kv_heads, new_len, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                k_cache[:, :, cache_position.long(), :] = key_states
                v_cache[:, :, cache_position.long(), :] = value_states
                self.layer_caches[layer_idx] = (k_cache, v_cache)

            self.layer_seq_lengths[layer_idx] = new_len
        else:
            k_cache, v_cache = self.layer_caches[layer_idx]
            current_len = k_cache.shape[2]

            if new_len > current_len:
                # Grow the cache with zero padding to reach the new length.
                extension_size = new_len - current_len
                k_extension = torch.zeros(
                    (self.batch_size, self.num_kv_heads, extension_size, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                v_extension = torch.zeros(
                    (self.batch_size, self.num_kv_heads, extension_size, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                k_cache = torch.cat([k_cache, k_extension], dim=2)
                v_cache = torch.cat([v_cache, v_extension], dim=2)

            # Scatter the new states into their positions.
            k_cache[:, :, cache_position.long(), :] = key_states
            v_cache[:, :, cache_position.long(), :] = value_states

            self.layer_caches[layer_idx] = (k_cache, v_cache)
            self.layer_seq_lengths[layer_idx] = max(
                self.layer_seq_lengths[layer_idx], new_len
            )

        # Record which positions are now filled for this layer.
        self.filled_positions[layer_idx].update(positions)

        return self.layer_caches[layer_idx]

    def get_kv(self, layer_idx: int) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Get cached KV for a layer, or None if not cached."""
        return self.layer_caches[layer_idx]

    def get_seq_length(self, layer_idx: int) -> int:
        """Get the cached sequence length (including unfilled holes) for a layer."""
        return self.layer_seq_lengths[layer_idx]

    def has_position(self, layer_idx: int, position: int) -> bool:
        """Check if a specific position is filled for a layer."""
        return position in self.filled_positions[layer_idx]

    def get_unfilled_positions(self, layer_idx: int, up_to: int) -> List[int]:
        """Get the positions not yet filled for a layer, up to `up_to` (exclusive)."""
        all_positions = set(range(up_to))
        filled = self.filled_positions[layer_idx]
        return sorted(all_positions - filled)

    def needs_fill(self, layer_idx: int, positions: List[int]) -> bool:
        """Check if any of the given positions need to be filled for a layer."""
        return not all(p in self.filled_positions[layer_idx] for p in positions)

    def get_missing_layers(self, position: int, target_layer: int) -> List[int]:
        """
        Get the list of layers that still need computation for a position.

        Args:
            position: The position we need KV for
            target_layer: The deepest layer we need to reach

        Returns:
            List of layer indices that need computation for this position
        """
        missing = []
        for layer_idx in range(target_layer + 1):
            if position not in self.filled_positions[layer_idx]:
                missing.append(layer_idx)
        return missing
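    # Illustrative example (hypothetical values): if layers 0-3 already hold
    # position 7 but layers 4 and 5 do not, get_missing_layers(7, 5) returns
    # [4, 5] -- the layers a lazy-fill pass must compute before layer 5 can
    # attend over position 7.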

    def truncate_from(self, position: int):
        """
        Truncate all layer caches from `position` onwards.
        Keeps positions 0..position-1. Used for rollback on rejection.

        Args:
            position: First position to remove
        """
        for layer_idx in range(self.num_layers):
            if self.layer_caches[layer_idx] is not None:
                k, v = self.layer_caches[layer_idx]
                if k.shape[2] > position:
                    self.layer_caches[layer_idx] = (
                        k[:, :, :position, :].contiguous(),
                        v[:, :, :position, :].contiguous(),
                    )
                self.layer_seq_lengths[layer_idx] = min(
                    self.layer_seq_lengths[layer_idx], position
                )

            # Forget bookkeeping for the removed positions.
            self.filled_positions[layer_idx] = {
                p for p in self.filled_positions[layer_idx] if p < position
            }
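    # Illustrative example (hypothetical values): if a rejected draft wrote
    # positions 10-13, truncate_from(10) drops those positions from every
    # layer while leaving positions 0-9 intact.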

    def clone(self) -> "JaggedKVCache":
        """
        Create a deep copy of the cache for speculative drafting.

        Returns:
            Independent copy that can be modified without affecting the original
        """
        new_cache = JaggedKVCache(
            num_layers=self.num_layers,
            batch_size=self.batch_size,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            device=self.device,
            dtype=self.dtype,
        )
        for i, kv in enumerate(self.layer_caches):
            if kv is not None:
                new_cache.layer_caches[i] = (kv[0].clone(), kv[1].clone())
        new_cache.layer_seq_lengths = self.layer_seq_lengths.copy()
        new_cache.filled_positions = [s.copy() for s in self.filled_positions]
        return new_cache

    def reset(self):
        """Reset the cache to an empty state."""
        self.layer_caches = [None for _ in range(self.num_layers)]
        self.layer_seq_lengths = [0] * self.num_layers
        self.filled_positions = [set() for _ in range(self.num_layers)]

    def __repr__(self) -> str:
        lines = [f"JaggedKVCache(num_layers={self.num_layers}, device={self.device})"]
        for i in range(min(self.num_layers, 10)):
            seq_len = self.layer_seq_lengths[i]
            filled = len(self.filled_positions[i])
            if seq_len > 0:
                lines.append(f"  Layer {i:2d}: {filled}/{seq_len} filled")
        if self.num_layers > 10:
            lines.append(f"  ... ({self.num_layers - 10} more layers)")
        return "\n".join(lines)
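

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): random tensors stand in for
    # real attention key/value states, and the layer/position choices below
    # are arbitrary. This exercises the jagged-length bookkeeping end to end.
    torch.manual_seed(0)
    cache = JaggedKVCache(num_layers=4, num_kv_heads=2, head_dim=8)

    def fake_kv(seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Stand-in for real key/value projections: [B, num_kv_heads, seq, dim].
        shape = (1, 2, seq_len, 8)
        return torch.randn(shape), torch.randn(shape)

    # Prefill positions 0-3 on every layer.
    k, v = fake_kv(4)
    for layer in range(4):
        cache.update(layer, k, v, torch.arange(4))

    # Decode step where the token exits early: only layers 0 and 1 see it.
    k, v = fake_kv(1)
    for layer in range(2):
        cache.update(layer, k, v, torch.tensor([4]))

    # Layers 2 and 3 now lack position 4: candidates for lazy fill.
    assert cache.get_missing_layers(position=4, target_layer=3) == [2, 3]

    # Snapshot for speculation, then roll back the drafted position.
    draft = cache.clone()
    k, v = fake_kv(1)
    draft.update(0, k, v, torch.tensor([5]))
    draft.truncate_from(5)
    assert draft.get_seq_length(0) == cache.get_seq_length(0)

    print(cache)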