|
|
""" |
|
|
Explicit KV Cache management for efficient inference. |
|
|
This is critical for Qualcomm deployment and agent control loops. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from typing import Optional, Tuple |
|
|
from dataclasses import dataclass |
|
|
|
|
|
|
|
|
@dataclass
class KVCache:
    """Key-Value cache for transformer inference.

    Layout: [num_layers, batch_size, num_heads, max_seq_len, head_dim]

    This explicit cache enables:
    - Efficient autoregressive decoding
    - Cache offloading for memory management
    - Sliding window attention (future)
    - Agent control loops with cache manipulation
    """

    # Pre-allocated key tensor: [num_layers, batch, heads, max_seq_len, head_dim]
    key_cache: torch.Tensor
    # Pre-allocated value tensor, same layout as key_cache.
    value_cache: torch.Tensor
    # Number of valid cached positions (high-water mark across all layers).
    seq_len: int

    @classmethod
    def create(
        cls,
        num_layers: int,
        batch_size: int,
        num_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[torch.device] = None,
    ) -> "KVCache":
        """Create an empty KV cache.

        Args:
            num_layers: Number of transformer layers
            batch_size: Batch size
            num_heads: Number of attention heads
            max_seq_len: Maximum sequence length
            head_dim: Dimension per attention head
            dtype: Data type for cache tensors
            device: Device to create cache on (None = torch default device)

        Returns:
            Initialized KVCache with zero tensors
        """
        shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)

        key_cache = torch.zeros(shape, dtype=dtype, device=device)
        value_cache = torch.zeros(shape, dtype=dtype, device=device)

        return cls(key_cache=key_cache, value_cache=value_cache, seq_len=0)

    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor,
        position: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache for a specific layer and return full K, V.

        Args:
            layer_idx: Index of the transformer layer
            key: New key tensor [batch, heads, seq_len, head_dim]
            value: New value tensor [batch, heads, seq_len, head_dim]
            position: Starting position for the new tokens

        Returns:
            Tuple of (full_key, full_value) including cached values

        Raises:
            ValueError: If the write would fall outside [0, max_seq_len).
                Without this check, an out-of-range slice silently clamps
                and the assignment fails with an opaque shape-mismatch error.
        """
        seq_len = key.shape[2]
        end_pos = position + seq_len
        max_seq_len = self.key_cache.shape[3]

        if position < 0 or end_pos > max_seq_len:
            raise ValueError(
                f"Cache write [{position}:{end_pos}) out of bounds "
                f"for max_seq_len={max_seq_len}"
            )

        # Write the new tokens into the pre-allocated slots for this layer.
        self.key_cache[layer_idx, :, :, position:end_pos, :] = key
        self.value_cache[layer_idx, :, :, position:end_pos, :] = value

        # Track the high-water mark; max() so out-of-order layer updates
        # within one step never shrink the valid region.
        self.seq_len = max(self.seq_len, end_pos)

        # Views (no copy) over the valid prefix including the new tokens.
        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def get(
        self,
        layer_idx: int,
        end_pos: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get cached K, V for a specific layer.

        Args:
            layer_idx: Index of the transformer layer
            end_pos: End position (defaults to current seq_len)

        Returns:
            Tuple of (key, value) tensors (views, not copies)
        """
        if end_pos is None:
            end_pos = self.seq_len

        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def reset(self):
        """Reset the cache to empty state.

        Zeroes both tensors in place (keeps the allocation) and clears
        the valid-length marker so the cache can be reused.
        """
        self.key_cache.zero_()
        self.value_cache.zero_()
        self.seq_len = 0

    @property
    def memory_usage_mb(self) -> float:
        """Total memory footprint of both cache tensors in megabytes."""
        total_bytes = self.key_cache.numel() * self.key_cache.element_size()
        total_bytes += self.value_cache.numel() * self.value_cache.element_size()
        return total_bytes / (1024 * 1024)
|
|
|