# PebbleLM-117M — src/model/kv_cache.py
"""
Explicit KV Cache management for efficient inference.
This is critical for Qualcomm deployment and agent control loops.
"""
import torch
from typing import Optional, Tuple
from dataclasses import dataclass
@dataclass
class KVCache:
    """Key-Value cache for transformer inference.

    Layout: [num_layers, batch_size, num_heads, max_seq_len, head_dim]

    This explicit cache enables:
    - Efficient autoregressive decoding
    - Cache offloading for memory management
    - Sliding window attention (future)
    - Agent control loops with cache manipulation
    """

    key_cache: torch.Tensor    # [num_layers, batch, heads, max_len, head_dim]
    value_cache: torch.Tensor  # [num_layers, batch, heads, max_len, head_dim]
    seq_len: int               # Current sequence length in cache

    @classmethod
    def create(
        cls,
        num_layers: int,
        batch_size: int,
        num_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[torch.device] = None,
    ) -> "KVCache":
        """Create an empty KV cache.

        Args:
            num_layers: Number of transformer layers
            batch_size: Batch size
            num_heads: Number of attention heads
            max_seq_len: Maximum sequence length (cache capacity)
            head_dim: Dimension per attention head
            dtype: Data type for cache tensors
            device: Device to create cache on (None uses torch's default)

        Returns:
            Initialized KVCache with zero tensors and seq_len == 0
        """
        shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)
        key_cache = torch.zeros(shape, dtype=dtype, device=device)
        value_cache = torch.zeros(shape, dtype=dtype, device=device)
        return cls(key_cache=key_cache, value_cache=value_cache, seq_len=0)

    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor,
        position: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache for a specific layer and return full K, V.

        Args:
            layer_idx: Index of the transformer layer
            key: New key tensor [batch, heads, seq_len, head_dim]
            value: New value tensor [batch, heads, seq_len, head_dim]
            position: Starting position for the new tokens

        Returns:
            Tuple of (full_key, full_value) covering positions
            [0, position + seq_len) of the cache for this layer.

        Raises:
            ValueError: If the new tokens would overflow the cache capacity.
        """
        seq_len = key.shape[2]
        end_pos = position + seq_len
        # Fail loudly with a clear message instead of letting the slice
        # assignment below raise an opaque shape-mismatch error when the
        # new tokens would run past the cache's max_seq_len.
        max_len = self.key_cache.shape[3]
        if end_pos > max_len:
            raise ValueError(
                f"KV cache overflow: position {position} + {seq_len} new "
                f"tokens exceeds max_seq_len {max_len}"
            )
        # Store new keys and values in place at their absolute positions
        self.key_cache[layer_idx, :, :, position:end_pos, :] = key
        self.value_cache[layer_idx, :, :, position:end_pos, :] = value
        # max() keeps seq_len monotonic even when earlier positions
        # are overwritten (e.g. agent loops rewinding the cache)
        self.seq_len = max(self.seq_len, end_pos)
        # Return full K, V up to the end of the newly written span
        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def get(
        self,
        layer_idx: int,
        end_pos: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get cached K, V for a specific layer.

        Args:
            layer_idx: Index of the transformer layer
            end_pos: End position (defaults to current seq_len)

        Returns:
            Tuple of (key, value) tensors sliced to [0, end_pos)
        """
        if end_pos is None:
            end_pos = self.seq_len
        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def reset(self):
        """Reset the cache to empty state (zeroes tensors in place)."""
        self.key_cache.zero_()
        self.value_cache.zero_()
        self.seq_len = 0

    @property
    def memory_usage_mb(self) -> float:
        """Calculate memory usage of both cache tensors in megabytes."""
        total_bytes = self.key_cache.numel() * self.key_cache.element_size()
        total_bytes += self.value_cache.numel() * self.value_cache.element_size()
        return total_bytes / (1024 * 1024)