"""Paged KV-cache block manager with hash-based automatic prefix caching.

Concepts (matching vLLM / SGLang terminology), with ``B`` denoting
``block_size``:

  Physical block: a fixed-size slot in the KV-cache pool that holds the K and V
                  tensors for ``block_size`` consecutive tokens of one sequence.

  Block table:    per-sequence list of physical block ids that holds the
                  sequence's KV in logical order. Position ``p`` of the
                  sequence lives in physical block ``block_table[p // B]`` at
                  slot ``p % B``.

  Prefix cache:   a content-addressed lookup from
                      hash(prev_block_hash, tuple_of_token_ids_in_block)
                  to a physical block id. When two sequences share a prefix
                  that aligns to a block boundary, the second sequence can
                  point its block_table at the cached blocks instead of
                  recomputing KV, and the scheduler can skip those tokens.

The "chained" hash means two prefixes match iff they are identical from
position 0 — exactly the property we need for prefix sharing.

This manager is allocation-only: it does NOT store the KV tensors. The
ModelRunner owns the actual ``[num_blocks, ...]`` tensors and consults the
block tables here to know where to write/read KV.
"""
| from __future__ import annotations |
|
|
| from collections import deque |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| from .request import Sequence |
|
|
|
|
@dataclass
class Block:
    """Bookkeeping record for one physical KV-cache block.

    Holds allocation metadata only; no tensor data is stored on it.
    """

    # Index of this block in the pool.
    block_id: int
    # Number of live references to this block; 0 means nobody holds it.
    ref_count: int = 0
    # Chained content hash while the block is registered in the prefix
    # cache, otherwise None.
    hash_key: Optional[int] = None
|
|
|
|
class BlockManager:
    """Allocator and reference counter for paged KV-cache blocks.

    Tracks which physical block ids are free or referenced and keeps a
    content-addressed map (``_cache``) from chained block hashes to block
    ids. Only bookkeeping lives here; no KV data is stored.

    Free blocks sit on one of two FIFO lists:

    * ``_free_uncached`` -- blocks without a registered hash. These are
      handed out first.
    * ``_free_cached`` -- unreferenced blocks whose hash is still present in
      ``_cache``. A prefix hit can revive them; they are evicted (oldest
      release first) only once ``_free_uncached`` is empty.
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        enable_prefix_caching: bool = True,
    ) -> None:
        """Create a manager for ``num_blocks`` blocks of ``block_size`` tokens.

        Raises:
            ValueError: if ``block_size`` is not positive or ``num_blocks``
                is negative.
        """
        # Fail fast: a non-positive block size would otherwise only surface
        # later as a ZeroDivisionError or nonsensical block counts.
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")

        self.num_blocks = num_blocks
        self.block_size = block_size
        self.enable_prefix_caching = enable_prefix_caching

        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]

        # Every block starts out free and unhashed.
        self._free_uncached: deque[int] = deque(range(num_blocks))
        self._free_cached: deque[int] = deque()
        # chained block hash -> block id
        self._cache: dict[int, int] = {}

        # Counters surfaced through snapshot().
        self.prefix_cache_hits = 0
        self.prefix_cache_lookups = 0

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    @property
    def num_free_blocks(self) -> int:
        """Number of blocks on either free list (cached or uncached)."""
        return len(self._free_uncached) + len(self._free_cached)

    @property
    def num_used_blocks(self) -> int:
        """Number of blocks that are on neither free list."""
        return self.num_blocks - self.num_free_blocks

    def snapshot(self) -> dict:
        """Cheap dict for the event stream / UI."""
        return {
            "num_blocks": self.num_blocks,
            "block_size": self.block_size,
            "num_free_blocks": self.num_free_blocks,
            "num_cached_entries": len(self._cache),
            "prefix_cache_hits": self.prefix_cache_hits,
            "prefix_cache_lookups": self.prefix_cache_lookups,
            "ref_counts": [b.ref_count for b in self.blocks],
            "hashed": [b.hash_key is not None for b in self.blocks],
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _block_hash(self, prev_hash: Optional[int], token_ids: tuple[int, ...]) -> int:
        """Return the chained hash of one block.

        The previous block's hash is folded in, so equal hashes imply equal
        prefixes from position 0 (up to hash collisions). Lookups in
        ``admit`` compare hashes only and never re-check the token ids, so a
        collision of the built-in ``hash`` would alias two different
        prefixes.
        """
        return hash((prev_hash, token_ids))

    def _take_free_block(self) -> int:
        """Pop a free block, set its ref_count to 1 and return its id.

        Unhashed blocks are preferred. When only cached free blocks remain,
        the oldest one is evicted: its hash is dropped from ``_cache``.

        Raises:
            RuntimeError: if both free lists are empty.
        """
        if self._free_uncached:
            bid = self._free_uncached.popleft()
            blk = self.blocks[bid]
        elif self._free_cached:
            bid = self._free_cached.popleft()
            blk = self.blocks[bid]
            if blk.hash_key is not None:
                # Evict: the block's old content is about to be overwritten.
                self._cache.pop(blk.hash_key, None)
                blk.hash_key = None
        else:
            raise RuntimeError("BlockManager out of free blocks")
        blk.ref_count = 1
        return bid

    def _share(self, block_id: int) -> None:
        """Add one reference to ``block_id``.

        If the block was idle (ref_count 0) it is first removed from the
        cached free list so it cannot be handed out while in use.
        """
        blk = self.blocks[block_id]
        if blk.ref_count == 0:
            try:
                self._free_cached.remove(block_id)
            except ValueError:
                # An idle block that is not on the cached free list is
                # tolerated here rather than treated as fatal.
                pass
        blk.ref_count += 1

    def _release(self, block_id: int) -> None:
        """Drop one reference; return the block to a free list at zero.

        Hashed blocks go to ``_free_cached`` when prefix caching is enabled
        (so they can still be hit); everything else goes to
        ``_free_uncached``.
        """
        blk = self.blocks[block_id]
        # Checked before mutating so a double free does not leave a negative
        # ref_count behind when the assertion fires.
        assert blk.ref_count > 0, f"block {block_id} refcount went negative"
        blk.ref_count -= 1
        if blk.ref_count == 0:
            if blk.hash_key is not None and self.enable_prefix_caching:
                self._free_cached.append(block_id)
            else:
                self._free_uncached.append(block_id)

    def _register(self, block_id: int, hash_key: int) -> None:
        """Publish ``block_id`` in the prefix cache under ``hash_key``.

        No-op when prefix caching is disabled. If ``hash_key`` is already
        mapped to some block, the existing entry wins and ``block_id`` is
        left unhashed.
        """
        if not self.enable_prefix_caching:
            return
        if hash_key in self._cache:
            return
        self.blocks[block_id].hash_key = hash_key
        self._cache[hash_key] = block_id

    def _undo_admit(self, seq: Sequence, registered_here: list[int]) -> None:
        """Revert a partially completed ``admit`` for ``seq``.

        Hashes that the failed call registered are withdrawn first, so the
        released blocks land on the uncached free list instead of staying
        reachable through the prefix cache.
        """
        for bid in registered_here:
            blk = self.blocks[bid]
            self._cache.pop(blk.hash_key, None)
            blk.hash_key = None
        for bid in reversed(seq.block_table):
            self._release(bid)
        seq.block_table.clear()

    # ------------------------------------------------------------------
    # Allocation API
    # ------------------------------------------------------------------

    def num_blocks_needed_for(self, num_tokens: int) -> int:
        """Return ceil(num_tokens / block_size)."""
        return (num_tokens + self.block_size - 1) // self.block_size

    def can_allocate_initial(self, seq: Sequence) -> tuple[bool, int]:
        """Worst-case allocation check for the prompt of `seq`, ignoring prefix
        cache hits. Returns (ok, num_new_blocks_needed)."""
        need = self.num_blocks_needed_for(seq.prompt_len)
        return self.num_free_blocks >= need, need

    def admit(self, seq: Sequence) -> None:
        """Set up `seq` in the cache.

        Walks the prompt block-by-block. Each full block of prompt tokens is
        hashed (chained on the previous block's hash). While every block so
        far has been a hit, a cache hit shares the cached block; from the
        first miss onward every block is allocated fresh, and its hash is
        registered unless that hash is already cached.

        Hits are only accepted as a contiguous prefix because
        ``num_computed_tokens`` is set to the number of hit tokens: a hit
        that follows a miss would claim a computed prefix whose earlier
        block is freshly allocated.

        When the prompt length is an exact multiple of the block size, the
        last full block is never taken from the cache, so at least one
        prompt token is always left to compute.

        The trailing partial block (if any) is allocated fresh and left
        un-hashed; `register_filled_blocks` hashes it once it fills up.

        NOTE(review): hashes of missed blocks are registered here, before
        this manager can know their KV has been written. A later ``admit``
        can therefore hit a block this sequence has not filled yet. Confirm
        the scheduler/runner guarantees the KV is written before any sharer
        reads it.

        Raises:
            RuntimeError: if the pool runs out of free blocks. In that case
                every block taken by this call is returned and
                ``seq.block_table`` is left empty.
        """
        assert not seq.block_table, "admit called on an already-admitted sequence"
        B = self.block_size
        prompt = seq.prompt_token_ids
        num_full, tail_len = divmod(seq.prompt_len, B)
        # Blocks with index >= max_hits may not be served from the cache.
        max_hits = num_full if tail_len else max(0, num_full - 1)

        prev_hash: Optional[int] = None
        cached_tokens = 0
        prefix_intact = self.enable_prefix_caching
        registered_here: list[int] = []

        try:
            for i in range(num_full):
                h = self._block_hash(prev_hash, tuple(prompt[i * B:(i + 1) * B]))
                prev_hash = h
                self.prefix_cache_lookups += 1

                if prefix_intact and i < max_hits and h in self._cache:
                    self.prefix_cache_hits += 1
                    bid = self._cache[h]
                    self._share(bid)
                    seq.block_table.append(bid)
                    cached_tokens += B
                    continue

                # First miss (or capped block): nothing after this point may
                # count as cached.
                prefix_intact = False
                bid = self._take_free_block()
                seq.block_table.append(bid)
                # Mirrors _register's own conditions; evaluated after the
                # allocation because that may have evicted this very hash.
                newly_registered = self.enable_prefix_caching and h not in self._cache
                self._register(bid, h)
                if newly_registered:
                    registered_here.append(bid)

            if tail_len:
                seq.block_table.append(self._take_free_block())
        except RuntimeError:
            self._undo_admit(seq, registered_here)
            raise

        seq.num_computed_tokens = cached_tokens
        seq.num_cached_prefix_tokens = cached_tokens

    def append_slot(self, seq: Sequence) -> Optional[int]:
        """Ensure `seq` has a slot for one more token (decode path).

        Returns the block_id that was newly allocated, or None if existing
        capacity already covered the new token. Raises RuntimeError if a
        block is needed and none is available.
        """
        new_position = seq.total_len
        if self.num_blocks_needed_for(new_position + 1) <= len(seq.block_table):
            return None
        if self.num_free_blocks == 0:
            raise RuntimeError("out of blocks")
        bid = self._take_free_block()
        seq.block_table.append(bid)
        return bid

    def ensure_blocks_for_chunk(self, seq: Sequence, chunk_tokens: int) -> int:
        """Prefill path: make sure `seq.block_table` covers
        `seq.num_computed_tokens + chunk_tokens` tokens.

        Returns number of newly-allocated blocks.

        Raises:
            RuntimeError: if the pool cannot supply all missing blocks. The
                check happens before anything is allocated, so
                ``seq.block_table`` is left untouched on failure.
        """
        target_tokens = seq.num_computed_tokens + chunk_tokens
        missing = self.num_blocks_needed_for(target_tokens) - len(seq.block_table)
        if missing <= 0:
            return 0
        if missing > self.num_free_blocks:
            raise RuntimeError("BlockManager out of free blocks")
        seq.block_table.extend(self._take_free_block() for _ in range(missing))
        return missing

    def free(self, seq: Sequence) -> None:
        """Release every block held by `seq` and empty its block table.

        Blocks are released last-to-first so that, on the FIFO cached free
        list, the tail of a prefix chain is evicted before its head. A tail
        block can only be hit after all blocks before it were hit, so
        evicting the head first would strand the rest of the chain.
        """
        for bid in reversed(seq.block_table):
            self._release(bid)
        seq.block_table.clear()

    # ------------------------------------------------------------------
    # Post-forward registration
    # ------------------------------------------------------------------

    def register_filled_blocks(self, seq: Sequence, prev_computed: int) -> None:
        """After a forward pass, hash & register any blocks that just became
        full so future requests can prefix-cache them.

        Walks all full blocks covered by ``seq.num_computed_tokens`` from the
        start: already-hashed blocks just advance the hash chain, unhashed
        ones are hashed from the sequence's tokens and offered to the cache.
        ``prev_computed`` is accepted for interface compatibility and is not
        used.
        """
        if not self.enable_prefix_caching:
            return
        B = self.block_size
        prev_hash: Optional[int] = None
        for i in range(seq.num_computed_tokens // B):
            bid = seq.block_table[i]
            blk = self.blocks[bid]
            if blk.hash_key is not None:
                prev_hash = blk.hash_key
                continue
            tokens = tuple(seq.get_token(p) for p in range(i * B, (i + 1) * B))
            prev_hash = self._block_hash(prev_hash, tokens)
            self._register(bid, prev_hash)
|
|