"""
Explicit KV Cache management for efficient inference.

This is critical for Qualcomm deployment and agent control loops.
"""

import torch
from typing import Optional, Tuple
from dataclasses import dataclass


@dataclass
class KVCache:
    """Key-Value cache for transformer inference.

    Layout: [num_layers, batch_size, num_heads, max_seq_len, head_dim]

    This explicit cache enables:
    - Efficient autoregressive decoding
    - Cache offloading for memory management
    - Sliding window attention (future)
    - Agent control loops with cache manipulation

    ``seq_len`` is the high-water mark of all writes made via ``update``.
    Writing at an earlier position (e.g. overwriting a prefix) does not
    shrink it, so ``get()`` without ``end_pos`` may still return entries
    written past the most recent update.
    """

    key_cache: torch.Tensor  # [num_layers, batch, heads, max_len, head_dim]
    value_cache: torch.Tensor  # [num_layers, batch, heads, max_len, head_dim]
    seq_len: int  # Current sequence length in cache (max end position written)

    @classmethod
    def create(
        cls,
        num_layers: int,
        batch_size: int,
        num_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[torch.device] = None,
    ) -> "KVCache":
        """Create an empty KV cache.

        Args:
            num_layers: Number of transformer layers
            batch_size: Batch size
            num_heads: Number of attention heads
            max_seq_len: Maximum sequence length
            head_dim: Dimension per attention head
            dtype: Data type for cache tensors
            device: Device to create cache on (None uses torch's default)

        Returns:
            Initialized KVCache with zero tensors
        """
        shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)
        key_cache = torch.zeros(shape, dtype=dtype, device=device)
        value_cache = torch.zeros(shape, dtype=dtype, device=device)
        return cls(key_cache=key_cache, value_cache=value_cache, seq_len=0)

    @property
    def max_seq_len(self) -> int:
        """Capacity of the cache along the sequence axis (dim 3)."""
        return self.key_cache.shape[3]

    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor,
        position: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache for a specific layer and return full K, V.

        Args:
            layer_idx: Index of the transformer layer
            key: New key tensor [batch, heads, seq_len, head_dim]
            value: New value tensor [batch, heads, seq_len, head_dim]
            position: Starting position for the new tokens

        Returns:
            Tuple of (full_key, full_value) covering positions [0, position + seq_len)
            for this layer, including previously cached values. These are views
            into the cache tensors, not copies.

        Raises:
            ValueError: If ``key`` and ``value`` shapes differ, ``position`` is
                negative, or the new tokens would extend past ``max_seq_len``.
        """
        # Validate everything before writing so a bad call never leaves the
        # key cache updated while the value cache is not.
        if key.shape != value.shape:
            raise ValueError(
                f"key and value must have the same shape, got "
                f"{tuple(key.shape)} and {tuple(value.shape)}"
            )
        if position < 0:
            # A negative start would index from the end of the buffer and
            # return a wrongly-sized ':end_pos' slice.
            raise ValueError(f"position must be non-negative, got {position}")

        seq_len = key.shape[2]
        end_pos = position + seq_len
        if end_pos > self.max_seq_len:
            raise ValueError(
                f"KV cache overflow: writing {seq_len} tokens at position "
                f"{position} needs {end_pos} slots, but max_seq_len is "
                f"{self.max_seq_len}"
            )

        # Store new keys and values
        self.key_cache[layer_idx, :, :, position:end_pos, :] = key
        self.value_cache[layer_idx, :, :, position:end_pos, :] = value

        # Track the furthest position written (never shrinks on overwrite)
        self.seq_len = max(self.seq_len, end_pos)

        # Return full K, V up to the end of the tokens just written
        return self.get(layer_idx, end_pos)

    def get(
        self,
        layer_idx: int,
        end_pos: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get cached K, V for a specific layer.

        Args:
            layer_idx: Index of the transformer layer
            end_pos: End position, exclusive (defaults to current seq_len)

        Returns:
            Tuple of (key, value) tensors sliced to [0, end_pos) along the
            sequence axis. These are views into the cache tensors.
        """
        if end_pos is None:
            end_pos = self.seq_len
        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def reset(self):
        """Reset the cache to empty state (zero both buffers in place)."""
        self.key_cache.zero_()
        self.value_cache.zero_()
        self.seq_len = 0

    @property
    def memory_usage_mb(self) -> float:
        """Calculate memory usage of both cache buffers in megabytes (MiB)."""
        total_bytes = self.key_cache.numel() * self.key_cache.element_size()
        total_bytes += self.value_cache.numel() * self.value_cache.element_size()
        return total_bytes / (1024 * 1024)