# tiny-vllm / tiny_vllm/block_manager.py
# enCoder's picture
# minimal continuous-batching LLM engine
# c32c359
"""Paged KV-cache block manager with hash-based automatic prefix caching.

Concepts (matching vLLM / SGLang terminology):

Physical block
    A fixed-size slot in the KV-cache pool that holds the K and V tensors
    for ``block_size`` consecutive tokens of one sequence.
Block table
    Per-sequence list of physical block ids holding the sequence's KV in
    logical order. Position ``p`` of the sequence lives in physical block
    ``block_table[p // B]`` at slot ``p % B`` (``B`` = ``block_size``).
Prefix cache
    A content-addressed lookup from
    ``hash(prev_block_hash, tuple_of_token_ids_in_block)`` to a physical
    block id. When two sequences share a prefix that aligns to a block
    boundary, the second sequence can point its block table at the cached
    blocks instead of recomputing KV, and the scheduler can skip those
    tokens.

The "chained" hash means two prefixes match iff (up to hash collisions) they
are identical from position 0 — exactly the property we need for prefix
sharing, because the K/V stored at position ``p`` depend on every token at
positions ``0..p``, not just on the tokens inside that block.

This manager is allocation-only: it does NOT store the KV tensors. The
ModelRunner owns the actual ``[num_blocks, ...]`` tensors and consults the
block tables here to know where to write/read KV.
"""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Optional
from .request import Sequence
@dataclass
class Block:
    """Metadata for one physical KV-cache block (no tensor data lives here)."""

    block_id: int
    # Number of sequences whose block table references this block; 0 means
    # the block is on one of BlockManager's free lists.
    ref_count: int = 0
    # Chained content hash. Set when the block is full and registered in the
    # prefix cache; reset to None when the block is evicted for reuse.
    hash_key: Optional[int] = None


class BlockManager:
    """Allocation-only manager for a paged KV-cache pool.

    Tracks which physical blocks each sequence uses (``seq.block_table``),
    reference-counts blocks shared between sequences, and optionally keeps a
    content-addressed prefix cache so sequences with a common, block-aligned
    prompt prefix can reuse already-computed KV.

    Free blocks are kept on two FIFO lists:

    * ``_free_uncached`` -- blocks with no prefix-cache entry; reused first.
    * ``_free_cached`` -- released blocks still registered in the prefix
      cache; reused (and their cache entry evicted) only when the uncached
      list is empty, in the order they were released.
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        enable_prefix_caching: bool = True,
    ) -> None:
        """Create a pool of ``num_blocks`` blocks holding ``block_size`` tokens each.

        Raises:
            ValueError: if ``block_size`` is not positive or ``num_blocks`` is
                negative (the block arithmetic below would be meaningless).
        """
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.enable_prefix_caching = enable_prefix_caching
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        # Two-tier free list: ephemeral (no hash) blocks are reused first;
        # cached blocks are preserved for as long as ephemeral capacity lasts.
        self._free_uncached: deque[int] = deque(range(num_blocks))
        self._free_cached: deque[int] = deque()
        self._cache: dict[int, int] = {}  # chained block hash -> block_id
        # Stats (visible via events, see snapshot()).
        self.prefix_cache_hits = 0
        self.prefix_cache_lookups = 0

    # ---- introspection --------------------------------------------------

    @property
    def num_free_blocks(self) -> int:
        """Blocks on either free list (cached-free blocks count as free)."""
        return len(self._free_uncached) + len(self._free_cached)

    @property
    def num_used_blocks(self) -> int:
        """Blocks currently referenced by at least one sequence."""
        return self.num_blocks - self.num_free_blocks

    def snapshot(self) -> dict:
        """Cheap dict for the event stream / UI."""
        return {
            "num_blocks": self.num_blocks,
            "block_size": self.block_size,
            "num_free_blocks": self.num_free_blocks,
            "num_cached_entries": len(self._cache),
            "prefix_cache_hits": self.prefix_cache_hits,
            "prefix_cache_lookups": self.prefix_cache_lookups,
            "ref_counts": [b.ref_count for b in self.blocks],
            "hashed": [b.hash_key is not None for b in self.blocks],
        }

    # ---- low-level pool ops --------------------------------------------

    def _block_hash(self, prev_hash: Optional[int], token_ids: tuple[int, ...]) -> int:
        """Chained hash of one full block.

        Folding in the previous block's hash makes equal hashes imply equal
        token prefixes from position 0 (modulo collisions).
        """
        # Python's hash() is randomized per process but that's fine: the cache
        # only lives for the engine's lifetime.
        # NOTE(review): a 64-bit collision would silently alias two different
        # prefixes; consider a stronger hash if that risk matters.
        return hash((prev_hash, token_ids))

    def _take_free_block(self) -> int:
        """Pop a free block, set its ref_count to 1, and return its id.

        Uncached free blocks are preferred. A cached free block is repurposed
        only when none remain, and its prefix-cache entry is evicted.

        Raises:
            RuntimeError: if both free lists are empty.
        """
        if self._free_uncached:
            bid = self._free_uncached.popleft()
        elif self._free_cached:
            bid = self._free_cached.popleft()
            # Evict its cache entry -- we're about to repurpose it.
            blk = self.blocks[bid]
            if blk.hash_key is not None:
                # Only drop the entry if it still points at this block.
                if self._cache.get(blk.hash_key) == bid:
                    del self._cache[blk.hash_key]
                blk.hash_key = None
        else:
            raise RuntimeError("BlockManager out of free blocks")
        self.blocks[bid].ref_count = 1
        return bid

    def _share(self, block_id: int) -> None:
        """Add one reference to an existing (prefix-cache-hit) block.

        A registered block at ref_count 0 sits on ``_free_cached`` (see
        ``_release``); it is pulled off so it cannot be handed out or evicted
        while in use.
        """
        blk = self.blocks[block_id]
        if blk.ref_count == 0:
            # deque.remove is O(len(_free_cached)); fine at these pool sizes.
            try:
                self._free_cached.remove(block_id)
            except ValueError:
                # Defensive only: a registered ref-0 block is expected to be on
                # the cached free list.
                pass
        blk.ref_count += 1

    def _release(self, block_id: int) -> None:
        """Drop one reference; at zero, return the block to a free list.

        Registered blocks go to the back of ``_free_cached`` (still hittable
        until evicted); all others go to ``_free_uncached``.
        """
        blk = self.blocks[block_id]
        blk.ref_count -= 1
        assert blk.ref_count >= 0, f"block {block_id} refcount went negative"
        if blk.ref_count == 0:
            if blk.hash_key is not None and self.enable_prefix_caching:
                self._free_cached.append(block_id)
            else:
                self._free_uncached.append(block_id)

    def _register(self, block_id: int, hash_key: int) -> None:
        """Record ``block_id`` as the prefix-cache entry for ``hash_key``.

        No-op when prefix caching is disabled, or when another block already
        owns ``hash_key``: the older block is kept and this one stays
        unhashed (ephemeral), so it is reclaimed first once released.
        """
        if not self.enable_prefix_caching:
            return
        if hash_key in self._cache:
            return
        self.blocks[block_id].hash_key = hash_key
        self._cache[hash_key] = block_id

    # ---- per-sequence allocation ---------------------------------------

    def num_blocks_needed_for(self, num_tokens: int) -> int:
        """Blocks required to hold ``num_tokens`` tokens (ceiling division)."""
        return (num_tokens + self.block_size - 1) // self.block_size

    def can_allocate_initial(self, seq: Sequence) -> tuple[bool, int]:
        """Worst-case allocation check for the prompt of `seq`, ignoring prefix
        cache hits. Returns (ok, num_new_blocks_needed).

        ``admit`` never consumes more free blocks than this, so ``ok`` being
        True guarantees ``admit`` will not run out of blocks.
        """
        need = self.num_blocks_needed_for(seq.prompt_len)
        return self.num_free_blocks >= need, need

    def admit(self, seq: Sequence) -> None:
        """Set up `seq` in the cache.

        Walks the prompt block-by-block. For each full block of prompt tokens,
        the prefix cache is consulted. A hit shares the cached block, but only
        while every earlier block was also a hit. Otherwise a fresh block is
        allocated and its hash is registered right away, because its tokens are
        already known.

        Why the hit run must be unbroken: ``num_computed_tokens`` is set to the
        number of cached tokens, which the scheduler treats as "KV already
        present for positions [0, cached_tokens)". A hit that follows a miss
        (head evicted, tail still cached) would claim KV for a fresh,
        uncomputed head block.

        The trailing partial block (if any) is allocated fresh and left
        un-hashed; `register_filled_blocks` hashes it once it fills up.

        Raises:
            RuntimeError: if the pool runs out mid-walk. Blocks already
                appended to ``seq.block_table`` are not rolled back; use
                ``can_allocate_initial`` beforehand.
        """
        assert not seq.block_table, "admit called on an already-admitted sequence"
        prompt = seq.prompt_token_ids
        B = self.block_size
        num_full = seq.prompt_len // B
        # IMPORTANT: never let prefix cache cover the entire prompt -- we need
        # at least one token to forward through the model to get logits for
        # the first sampled token. If the prompt is block-aligned, the last
        # full block is never taken from the cache.
        cap_full = num_full
        if seq.prompt_len % B == 0:
            cap_full = max(0, num_full - 1)

        prev_hash: Optional[int] = None
        cached_tokens = 0
        # Cache hits are only usable as an unbroken run starting at block 0.
        hit_run_open = self.enable_prefix_caching
        for i in range(num_full):
            tokens = tuple(prompt[i * B : (i + 1) * B])
            h = self._block_hash(prev_hash, tokens)
            prev_hash = h
            self.prefix_cache_lookups += 1
            if hit_run_open and i < cap_full and h in self._cache:
                # Cache hit that extends the contiguous cached prefix.
                self.prefix_cache_hits += 1
                bid = self._cache[h]
                self._share(bid)
                seq.block_table.append(bid)
                cached_tokens += B
                continue
            # Miss (or hit beyond the cap / after a gap): from here on every
            # block must be computed, so no later block may count as cached.
            hit_run_open = False
            bid = self._take_free_block()
            # NOTE(review): registered before its KV is computed; another
            # sequence admitted before this one's prefill finishes could hit
            # it. Confirm the scheduler tolerates this.
            self._register(bid, h)
            seq.block_table.append(bid)

        # Trailing partial block, if any.
        if seq.prompt_len % B != 0:
            seq.block_table.append(self._take_free_block())

        seq.num_computed_tokens = cached_tokens
        seq.num_cached_prefix_tokens = cached_tokens

    def append_slot(self, seq: Sequence) -> Optional[int]:
        """Ensure `seq` has a slot for one more token (decode path).

        Returns the block_id that was newly allocated, or None if existing
        capacity already covered the new token.

        Raises:
            RuntimeError: if a new block is needed and none is free.
        """
        # 0-indexed slot we are about to write. NOTE(review): this assumes
        # seq.total_len does not yet count the token being written -- confirm
        # against Sequence.
        new_position = seq.total_len
        needed_blocks = self.num_blocks_needed_for(new_position + 1)
        if needed_blocks <= len(seq.block_table):
            return None
        if self.num_free_blocks == 0:
            raise RuntimeError("out of blocks")
        bid = self._take_free_block()
        seq.block_table.append(bid)
        return bid

    def ensure_blocks_for_chunk(self, seq: Sequence, chunk_tokens: int) -> int:
        """Prefill path: make sure `seq.block_table` covers
        `seq.num_computed_tokens + chunk_tokens` tokens.

        Returns the number of newly allocated blocks.

        Raises:
            RuntimeError: if the pool runs out; blocks allocated before that
                point stay in ``seq.block_table`` (released by ``free``).
        """
        target = seq.num_computed_tokens + chunk_tokens
        needed = self.num_blocks_needed_for(target)
        new_alloc = 0
        while len(seq.block_table) < needed:
            seq.block_table.append(self._take_free_block())
            new_alloc += 1
        return new_alloc

    def free(self, seq: Sequence) -> None:
        """Release every block held by ``seq`` and clear its block table.

        Blocks are released tail-first. Released cached blocks join the back
        of the FIFO ``_free_cached`` list, so the end of a prefix chain is
        evicted before its head. Lookups in ``admit`` walk from block 0, so a
        chain that lost its head could never be hit again -- evicting heads
        first would waste the rest of the chain.
        """
        for bid in reversed(seq.block_table):
            self._release(bid)
        seq.block_table.clear()

    # ---- post-step bookkeeping -----------------------------------------

    def register_filled_blocks(self, seq: Sequence, prev_computed: int) -> None:
        """After a forward pass, hash & register any fully-computed blocks of
        `seq` that are not yet in the prefix cache, so future requests can
        prefix-cache them.

        Args:
            seq: sequence whose ``num_computed_tokens`` was just advanced.
            prev_computed: unused -- the hash chain is rebuilt from block 0 on
                every call. Kept for interface compatibility.
        """
        if not self.enable_prefix_caching:
            return
        B = self.block_size
        # Re-chain hashes from the start so prev_hash is always correct.
        prev_hash: Optional[int] = None
        # Only blocks whose every slot has been computed are eligible.
        for i in range(seq.num_computed_tokens // B):
            bid = seq.block_table[i]
            blk = self.blocks[bid]
            if blk.hash_key is not None:
                prev_hash = blk.hash_key
                continue
            tokens = tuple(seq.get_token(i * B + j) for j in range(B))
            h = self._block_hash(prev_hash, tokens)
            # No-op if another block already owns h; this block then stays
            # unhashed and is re-hashed on later calls.
            self._register(bid, h)
            prev_hash = h