from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Sequence import numpy as np from ...config import DotCacheConfig from ...encode import encode_page from ...model_kv_cache import ModelPagedKVCache, default_q_head_to_kv_head from ...page_cache import PreparedPageCache from ...tracing import ExecutionTrace from ...types import EncodedPage @dataclass(frozen=True, slots=True) class VllmBlockKey: layer_id: int kv_head_id: int block_id: int kind: str @dataclass(slots=True) class VllmBlockEntry: key: VllmBlockKey page: EncodedPage finalized: bool token_count: int token_start: int @dataclass(slots=True) class _LiveBlockState: block_id: int | None = None token_start: int | None = None key_rows: list[np.ndarray] = field(default_factory=list) value_rows: list[np.ndarray] = field(default_factory=list) def clear(self) -> None: self.block_id = None self.token_start = None self.key_rows.clear() self.value_rows.clear() def _normalize_block_tensor( values: np.ndarray, *, num_key_value_heads: int, block_size: int, head_dim: int, name: str, ) -> np.ndarray: array = np.asarray(values, dtype=np.float32) if array.ndim == 5: if array.shape[0] != 1: raise ValueError(f"{name} batch dimension must be 1") array = array[0] if array.ndim != 4: raise ValueError(f"{name} must have shape [kv_heads, block_count, block_size, head_dim]") if int(array.shape[0]) != num_key_value_heads: raise ValueError(f"{name} must contain {num_key_value_heads} KV heads") if int(array.shape[2]) != block_size: raise ValueError(f"{name} block size must equal {block_size}") if int(array.shape[3]) != head_dim: raise ValueError(f"{name} head_dim must equal {head_dim}") return array def _normalize_step_tensor( values: np.ndarray, *, num_key_value_heads: int, head_dim: int, name: str, ) -> np.ndarray: array = np.asarray(values, dtype=np.float32) if array.ndim == 4: if array.shape[0] != 1: raise ValueError(f"{name} batch dimension must be 1") array = array[0] if array.ndim != 3: raise ValueError(f"{name} must have shape [kv_heads, token_count, head_dim]") if int(array.shape[0]) != num_key_value_heads: raise ValueError(f"{name} must contain {num_key_value_heads} KV heads") if int(array.shape[2]) != head_dim: raise ValueError(f"{name} head_dim must equal {head_dim}") return array class VllmPagedKVCache: def __init__( self, *, config: DotCacheConfig, num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, block_size: int, backend: str = "torch_cuda", cache: PreparedPageCache | None = None, ) -> None: if config.tokens_per_page != block_size: raise ValueError("DotCache tokens_per_page must equal the vLLM block_size for this phase") self.config = config self.block_size = int(block_size) self.num_hidden_layers = int(num_hidden_layers) self.num_attention_heads = int(num_attention_heads) self.num_key_value_heads = int(num_key_value_heads) self.backend = backend self.cache = cache if cache is not None else PreparedPageCache() self.model_kv_cache = ModelPagedKVCache( config=config, num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, backend=backend, cache=self.cache, ) self.default_q_head_to_kv_head = default_q_head_to_kv_head(num_attention_heads, num_key_value_heads) self._blocks: dict[VllmBlockKey, VllmBlockEntry] = {} self._live_states: dict[tuple[int, int], _LiveBlockState] = {} @property def resident_bytes(self) -> int: return self.model_kv_cache.resident_bytes def clear(self) -> None: self._blocks.clear() self._live_states.clear() self.model_kv_cache.clear() def block_entry(self, layer_id: int, kv_head_id: int, block_id: int, kind: str) -> VllmBlockEntry: key = VllmBlockKey(int(layer_id), int(kv_head_id), int(block_id), kind) return self._blocks[key] def block_entries_for_layer(self, layer_id: int, *, kind: str) -> list[VllmBlockEntry]: return sorted( [entry for key, entry in self._blocks.items() if key.layer_id == layer_id and key.kind == kind], key=lambda entry: (entry.key.kv_head_id, entry.key.block_id), ) def _remove_layer_blocks(self, layer_id: int) -> None: stale_keys = [key for key in self._blocks if key.layer_id == layer_id] for key in stale_keys: del self._blocks[key] stale_live = [key for key in self._live_states if key[0] == layer_id] for key in stale_live: del self._live_states[key] def sync_layer_blocks( self, layer_id: int, key_blocks: np.ndarray, value_blocks: np.ndarray, *, block_ids: Sequence[int] | None = None, live_block_token_count: int = 0, trace: ExecutionTrace | None = None, ) -> None: keys = _normalize_block_tensor( key_blocks, num_key_value_heads=self.num_key_value_heads, block_size=self.block_size, head_dim=self.config.head_dim, name="key_blocks", ) values = _normalize_block_tensor( value_blocks, num_key_value_heads=self.num_key_value_heads, block_size=self.block_size, head_dim=self.config.head_dim, name="value_blocks", ) if keys.shape[1] != values.shape[1]: raise ValueError("key_blocks and value_blocks must contain the same number of blocks") block_count = int(keys.shape[1]) resolved_block_ids = tuple(range(block_count)) if block_ids is None else tuple(int(block_id) for block_id in block_ids) if len(resolved_block_ids) != block_count: raise ValueError("block_ids must align with the number of blocks") if len(set(resolved_block_ids)) != len(resolved_block_ids): raise ValueError("block_ids must be unique") if live_block_token_count < 0 or live_block_token_count > self.block_size: raise ValueError("live_block_token_count must be in [0, block_size]") self._remove_layer_blocks(layer_id) self.model_kv_cache.clear_layer(layer_id) if block_count == 0: return finalized_block_count = block_count if live_block_token_count in (0, self.block_size) else block_count - 1 full_tokens = finalized_block_count * self.block_size dense_keys = keys[:, :finalized_block_count].reshape( self.num_key_value_heads, full_tokens, self.config.head_dim, ) if finalized_block_count > 0 else np.zeros((self.num_key_value_heads, 0, self.config.head_dim), dtype=np.float32) dense_values = values[:, :finalized_block_count].reshape( self.num_key_value_heads, full_tokens, self.config.head_dim, ) if finalized_block_count > 0 else np.zeros((self.num_key_value_heads, 0, self.config.head_dim), dtype=np.float32) if finalized_block_count < block_count: live_keys = keys[:, finalized_block_count, :live_block_token_count] live_values = values[:, finalized_block_count, :live_block_token_count] dense_keys = np.concatenate([dense_keys, live_keys], axis=1) dense_values = np.concatenate([dense_values, live_values], axis=1) self.model_kv_cache.ingest_prefill_cache(layer_id, dense_keys, dense_values, trace=trace) self.model_kv_cache.prepare_static_pages(trace=trace) for block_index, block_id in enumerate(resolved_block_ids[:finalized_block_count]): token_start = block_index * self.block_size for kv_head_id in range(self.num_key_value_heads): key_page = encode_page( keys[kv_head_id, block_index], self.config, kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start, mode=self.config.resolve_page_mode(kind="K", layer_id=layer_id, kv_head_id=kv_head_id), build_runtime_metadata=False, ) value_page = encode_page( values[kv_head_id, block_index], self.config, kind="V", layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start, mode=self.config.resolve_page_mode(kind="V", layer_id=layer_id, kv_head_id=kv_head_id), build_runtime_metadata=False, ) self._blocks[VllmBlockKey(layer_id, kv_head_id, block_id, "K")] = VllmBlockEntry( key=VllmBlockKey(layer_id, kv_head_id, block_id, "K"), page=key_page, finalized=True, token_count=self.block_size, token_start=token_start, ) self._blocks[VllmBlockKey(layer_id, kv_head_id, block_id, "V")] = VllmBlockEntry( key=VllmBlockKey(layer_id, kv_head_id, block_id, "V"), page=value_page, finalized=True, token_count=self.block_size, token_start=token_start, ) if finalized_block_count < block_count: live_block_id = resolved_block_ids[finalized_block_count] live_token_start = finalized_block_count * self.block_size for kv_head_id in range(self.num_key_value_heads): live_key_rows = keys[kv_head_id, finalized_block_count, :live_block_token_count] live_value_rows = values[kv_head_id, finalized_block_count, :live_block_token_count] key_page = encode_page( live_key_rows, self.config, kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=live_token_start, mode="M3", build_runtime_metadata=False, ) value_page = encode_page( live_value_rows, self.config, kind="V", layer_id=layer_id, kv_head_id=kv_head_id, token_start=live_token_start, mode="M3", build_runtime_metadata=False, ) self._blocks[VllmBlockKey(layer_id, kv_head_id, live_block_id, "K")] = VllmBlockEntry( key=VllmBlockKey(layer_id, kv_head_id, live_block_id, "K"), page=key_page, finalized=False, token_count=live_block_token_count, token_start=live_token_start, ) self._blocks[VllmBlockKey(layer_id, kv_head_id, live_block_id, "V")] = VllmBlockEntry( key=VllmBlockKey(layer_id, kv_head_id, live_block_id, "V"), page=value_page, finalized=False, token_count=live_block_token_count, token_start=live_token_start, ) state = _LiveBlockState( block_id=live_block_id, token_start=live_token_start, key_rows=[np.asarray(row, dtype=np.float32) for row in live_key_rows], value_rows=[np.asarray(row, dtype=np.float32) for row in live_value_rows], ) self._live_states[(layer_id, kv_head_id)] = state def append_step( self, layer_id: int, key_step: np.ndarray, value_step: np.ndarray, token_index: int, *, trace: ExecutionTrace | None = None, ) -> None: keys = _normalize_step_tensor( key_step, num_key_value_heads=self.num_key_value_heads, head_dim=self.config.head_dim, name="key_step", ) values = _normalize_step_tensor( value_step, num_key_value_heads=self.num_key_value_heads, head_dim=self.config.head_dim, name="value_step", ) self.model_kv_cache.append_step(layer_id, keys, values, token_index, trace=trace) self._update_block_entries_from_steps(layer_id, keys, values, token_index) def append_step_torch( self, layer_id: int, key_step, value_step, token_index: int, *, trace: ExecutionTrace | None = None, ) -> None: self.model_kv_cache.append_step_torch(layer_id, key_step, value_step, token_index, trace=trace) keys = key_step.detach().to(dtype=key_step.dtype).cpu().numpy().astype(np.float32, copy=False) values = value_step.detach().to(dtype=value_step.dtype).cpu().numpy().astype(np.float32, copy=False) self._update_block_entries_from_steps(layer_id, keys, values, token_index) def _store_block_entry( self, layer_id: int, kv_head_id: int, block_id: int, kind: str, rows: np.ndarray, *, token_start: int, ) -> None: token_count = int(rows.shape[0]) finalized = token_count == self.block_size page = encode_page( np.asarray(rows, dtype=np.float32), self.config, kind=kind, layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start, mode=( self.config.resolve_page_mode(kind=kind, layer_id=layer_id, kv_head_id=kv_head_id) if finalized else "M3" ), build_runtime_metadata=False, ) self._blocks[VllmBlockKey(layer_id, kv_head_id, block_id, kind)] = VllmBlockEntry( key=VllmBlockKey(layer_id, kv_head_id, block_id, kind), page=page, finalized=finalized, token_count=token_count, token_start=token_start, ) def _update_block_entries_from_chunk( self, layer_id: int, key_rows_by_head: np.ndarray, value_rows_by_head: np.ndarray, token_index: int, ) -> None: token_count = int(key_rows_by_head.shape[1]) if token_count == 0: return block_id = int(token_index) // self.block_size block_offset = int(token_index) % self.block_size token_start = block_id * self.block_size if block_offset + token_count > self.block_size: raise ValueError("chunk update cannot span multiple blocks") for kv_head_id in range(self.num_key_value_heads): state = self._live_states.setdefault((layer_id, kv_head_id), _LiveBlockState()) if block_offset == 0 or state.block_id != block_id: state.clear() state.block_id = block_id state.token_start = token_start state.key_rows.extend(np.asarray(row, dtype=np.float32) for row in key_rows_by_head[kv_head_id]) state.value_rows.extend(np.asarray(row, dtype=np.float32) for row in value_rows_by_head[kv_head_id]) key_rows = np.stack(state.key_rows, axis=0).astype(np.float32, copy=False) value_rows = np.stack(state.value_rows, axis=0).astype(np.float32, copy=False) self._store_block_entry( layer_id, kv_head_id, block_id, "K", key_rows, token_start=token_start, ) self._store_block_entry( layer_id, kv_head_id, block_id, "V", value_rows, token_start=token_start, ) if key_rows.shape[0] == self.block_size: state.clear() def append_tokens_torch( self, layer_id: int, key_rows, value_rows, token_positions, *, trace: ExecutionTrace | None = None, ) -> None: positions = token_positions.reshape(-1).detach().cpu().numpy().astype(np.int64, copy=False) if positions.size == 0: return expected = np.arange(int(positions[0]), int(positions[0]) + positions.size, dtype=np.int64) if not np.array_equal(positions, expected): raise ValueError("Phase 6 vLLM adapter requires contiguous batch=1 token positions") keys_by_head = key_rows.transpose(0, 1) values_by_head = value_rows.transpose(0, 1) offset = 0 while offset < positions.size: token_index = int(positions[offset]) block_offset = token_index % self.block_size chunk_size = min(self.block_size - block_offset, positions.size - offset) key_chunk = keys_by_head[:, offset : offset + chunk_size] value_chunk = values_by_head[:, offset : offset + chunk_size] self.model_kv_cache.append_step_torch( layer_id, key_chunk, value_chunk, token_index, trace=trace, ) key_chunk_cpu = key_chunk.detach().cpu().numpy().astype(np.float32, copy=False) value_chunk_cpu = value_chunk.detach().cpu().numpy().astype(np.float32, copy=False) self._update_block_entries_from_chunk(layer_id, key_chunk_cpu, value_chunk_cpu, token_index) offset += chunk_size def _update_block_entries_from_steps( self, layer_id: int, key_rows_by_head: np.ndarray, value_rows_by_head: np.ndarray, token_index: int, ) -> None: token_count = int(key_rows_by_head.shape[1]) for offset in range(token_count): absolute_token = int(token_index) + offset block_id = absolute_token // self.block_size block_offset = absolute_token % self.block_size token_start = block_id * self.block_size for kv_head_id in range(self.num_key_value_heads): state = self._live_states.setdefault((layer_id, kv_head_id), _LiveBlockState()) if block_offset == 0 or state.block_id != block_id: state.clear() state.block_id = block_id state.token_start = token_start state.key_rows.append(np.asarray(key_rows_by_head[kv_head_id, offset], dtype=np.float32)) state.value_rows.append(np.asarray(value_rows_by_head[kv_head_id, offset], dtype=np.float32)) key_rows = np.stack(state.key_rows, axis=0).astype(np.float32, copy=False) value_rows = np.stack(state.value_rows, axis=0).astype(np.float32, copy=False) finalized = key_rows.shape[0] == self.block_size key_mode = self.config.resolve_page_mode(kind="K", layer_id=layer_id, kv_head_id=kv_head_id) if finalized else "M3" value_mode = self.config.resolve_page_mode(kind="V", layer_id=layer_id, kv_head_id=kv_head_id) if finalized else "M3" key_page = encode_page( key_rows, self.config, kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start, mode=key_mode, build_runtime_metadata=False, ) value_page = encode_page( value_rows, self.config, kind="V", layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start, mode=value_mode, build_runtime_metadata=False, ) self._blocks[VllmBlockKey(layer_id, kv_head_id, block_id, "K")] = VllmBlockEntry( key=VllmBlockKey(layer_id, kv_head_id, block_id, "K"), page=key_page, finalized=finalized, token_count=int(key_rows.shape[0]), token_start=token_start, ) self._blocks[VllmBlockKey(layer_id, kv_head_id, block_id, "V")] = VllmBlockEntry( key=VllmBlockKey(layer_id, kv_head_id, block_id, "V"), page=value_page, finalized=finalized, token_count=int(value_rows.shape[0]), token_start=token_start, ) if finalized: state.clear() def decode_layer( self, layer_id: int, query_step: np.ndarray, q_head_to_kv_head: Sequence[int] | np.ndarray, *, query_scale: float = 1.0, trace: ExecutionTrace | None = None, ) -> np.ndarray: return self.model_kv_cache.decode_layer( layer_id, query_step, q_head_to_kv_head, query_scale=query_scale, trace=trace, ) def decode_layer_torch( self, layer_id: int, query_step, q_head_to_kv_head: Sequence[int] | np.ndarray, *, query_scale: float = 1.0, trace: ExecutionTrace | None = None, ): return self.model_kv_cache.decode_layer_torch( layer_id, query_step, q_head_to_kv_head, query_scale=query_scale, trace=trace, )