File size: 10,748 Bytes
"""Paged KV-cache block manager with hash-based automatic prefix caching.
Concepts (matching vLLM / SGLang terminology):
Physical block: a fixed-size slot in the KV-cache pool that holds the K and V
tensors for ``block_size`` consecutive tokens of one sequence.
Block table: per-sequence list of physical block ids that holds the
sequence's KV in logical order. Position ``p`` of the
sequence lives in physical block ``block_table[p // B]`` at
slot ``p % B``.
Prefix cache: a content-addressed lookup from
hash(prev_block_hash, tuple_of_token_ids_in_block)
to a physical block id. When two sequences share a prefix
that aligns to a block boundary, the second sequence can
point its block_table at the cached blocks instead of
recomputing KV, and the scheduler can skip those tokens.
The "chained" hash means two prefixes match iff they are identical from
position 0 — exactly the property we need for prefix sharing.
This manager is allocation-only: it does NOT store the KV tensors. The
ModelRunner owns the actual ``[num_blocks, ...]`` tensors and consults the
block tables here to know where to write/read KV.
"""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Optional
from .request import Sequence
@dataclass
class Block:
    """Bookkeeping record for a single physical KV-cache block.

    Instances carry allocation metadata only; no tensor data is stored here.
    """

    # Identifier of the block within the physical pool.
    block_id: int
    # Reference count; starts at zero for a block nobody points at.
    ref_count: int = 0
    # Content hash, set when the block is full and registered.
    hash_key: Optional[int] = None
class BlockManager:
    """Allocate, share and recycle fixed-size KV-cache blocks.

    The manager tracks ownership only; it never touches KV tensors. Every
    block is reference counted. When a count drops to zero the block goes
    back to one of two free lists:

    * ``_free_uncached``: blocks without a registered content hash. These
      are handed out first.
    * ``_free_cached``: blocks still reachable through the prefix cache.
      They are recycled (and their cache entry evicted) only once
      ``_free_uncached`` is empty.
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        enable_prefix_caching: bool = True,
    ) -> None:
        """Create a pool of ``num_blocks`` free blocks.

        Args:
            num_blocks: Number of physical blocks in the pool.
            block_size: Number of token slots per block.
            enable_prefix_caching: When False, blocks are never registered
                in, nor looked up from, the prefix cache.

        Raises:
            ValueError: If ``block_size`` is not positive or ``num_blocks``
                is negative.
        """
        # Fail here rather than with a ZeroDivisionError deep inside the
        # block-count arithmetic later on.
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.enable_prefix_caching = enable_prefix_caching
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        # Two-tier free list: ephemeral (no hash) reused first, then cached
        # (preserved as long as we have ephemeral capacity).
        self._free_uncached: deque[int] = deque(range(num_blocks))
        self._free_cached: deque[int] = deque()
        self._cache: dict[int, int] = {}  # hash -> block_id
        # Stats (visible via events).
        self.prefix_cache_hits = 0
        self.prefix_cache_lookups = 0

    # ---- introspection --------------------------------------------------
    @property
    def num_free_blocks(self) -> int:
        """Number of blocks on either free list."""
        return len(self._free_uncached) + len(self._free_cached)

    @property
    def num_used_blocks(self) -> int:
        """Number of blocks that are on neither free list."""
        return self.num_blocks - self.num_free_blocks

    def snapshot(self) -> dict:
        """Cheap dict for the event stream / UI."""
        return {
            "num_blocks": self.num_blocks,
            "block_size": self.block_size,
            "num_free_blocks": self.num_free_blocks,
            "num_cached_entries": len(self._cache),
            "prefix_cache_hits": self.prefix_cache_hits,
            "prefix_cache_lookups": self.prefix_cache_lookups,
            "ref_counts": [b.ref_count for b in self.blocks],
            "hashed": [b.hash_key is not None for b in self.blocks],
        }

    # ---- low-level pool ops --------------------------------------------
    def _block_hash(self, prev_hash: Optional[int], token_ids: tuple[int, ...]) -> int:
        """Return the chained hash of one full block.

        ``prev_hash`` is the hash of the preceding block (``None`` for the
        first block), so equal results imply equal hashed content from
        position 0 onward, barring collisions.
        """
        # hash() values are only relied on within this process, which is
        # enough: the cache only lives for the engine's lifetime.
        # NOTE(review): a hit is trusted on the hash alone; token contents
        # are not compared, so a hash collision would alias two prefixes.
        return hash((prev_hash, token_ids))

    def _take_free_block(self) -> int:
        """Pop a free block, give it ``ref_count == 1`` and return its id.

        Uncached blocks are preferred. Any content hash still attached to
        the chosen block is dropped together with its cache entry.

        Raises:
            RuntimeError: If both free lists are empty.
        """
        if self._free_uncached:
            bid = self._free_uncached.popleft()
        elif self._free_cached:
            bid = self._free_cached.popleft()
        else:
            raise RuntimeError("BlockManager out of free blocks")
        blk = self.blocks[bid]
        if blk.hash_key is not None:
            # Evict its cache entry: the block is about to be repurposed.
            # Only drop the mapping if it really points at this block.
            if self._cache.get(blk.hash_key) == bid:
                del self._cache[blk.hash_key]
            blk.hash_key = None
        blk.ref_count = 1
        return bid

    def _share(self, block_id: int) -> None:
        """Add one reference to ``block_id``, reviving it if it was free."""
        blk = self.blocks[block_id]
        if blk.ref_count == 0:
            # The block is sitting on a free list; pull it out so it cannot
            # be handed to another sequence while shared. It is normally on
            # the cached list; the uncached list is checked as a fallback.
            for free_list in (self._free_cached, self._free_uncached):
                try:
                    free_list.remove(block_id)
                except ValueError:
                    continue
                break
        blk.ref_count += 1

    def _release(self, block_id: int) -> None:
        """Drop one reference; put the block on a free list at zero.

        Raises:
            AssertionError: If the block has no outstanding references.
        """
        blk = self.blocks[block_id]
        # Check before mutating so a double free cannot corrupt the count,
        # and raise explicitly so the check survives ``python -O``.
        if blk.ref_count <= 0:
            raise AssertionError(f"block {block_id} refcount went negative")
        blk.ref_count -= 1
        if blk.ref_count == 0:
            if blk.hash_key is not None and self.enable_prefix_caching:
                self._free_cached.append(block_id)
            else:
                self._free_uncached.append(block_id)

    def _register(self, block_id: int, hash_key: int) -> None:
        """Publish ``block_id`` in the prefix cache under ``hash_key``.

        Does nothing when prefix caching is disabled or when the hash is
        already mapped to a block.
        """
        if not self.enable_prefix_caching:
            return
        if hash_key in self._cache:
            # Two sequences independently produced the same content for
            # different physical blocks. Keep the older one; this one stays
            # ephemeral so it gets reclaimed first.
            return
        self.blocks[block_id].hash_key = hash_key
        self._cache[hash_key] = block_id

    # ---- per-sequence allocation ---------------------------------------
    def num_blocks_needed_for(self, num_tokens: int) -> int:
        """Return the number of blocks needed to hold ``num_tokens`` tokens."""
        return (num_tokens + self.block_size - 1) // self.block_size

    def can_allocate_initial(self, seq: Sequence) -> tuple[bool, int]:
        """Worst-case allocation check for the prompt of `seq`, ignoring prefix
        cache hits. Returns (ok, num_new_blocks_needed)."""
        need = self.num_blocks_needed_for(seq.prompt_len)
        return self.num_free_blocks >= need, need

    def admit(self, seq: Sequence) -> None:
        """Set up `seq` in the cache.

        Walks the prompt block-by-block. Each full block of prompt tokens is
        looked up in the prefix cache: hit -> share the cached block; miss ->
        allocate a fresh block and register its hash now (the tokens are
        already known).

        Hits are accepted only while they form an unbroken prefix. Cache
        entries are evicted independently, so a later block of a chain may
        still be cached after an earlier one is gone. Counting such a block
        would make ``num_computed_tokens`` cover positions that have not
        been computed.

        The trailing partial block (if any) is allocated fresh and left
        un-hashed.

        On return ``seq.num_computed_tokens`` and
        ``seq.num_cached_prefix_tokens`` both equal the number of prompt
        tokens covered by shared blocks.

        Raises:
            RuntimeError: If the pool runs out of free blocks.
        """
        assert not seq.block_table, "admit called on an already-admitted sequence"
        B = self.block_size
        prompt = seq.prompt_token_ids
        num_full, remainder = divmod(seq.prompt_len, B)
        # IMPORTANT: never let the prefix cache cover the entire prompt. At
        # least one token must be forwarded through the model to get logits
        # for the first sampled token. If the prompt is block-aligned, the
        # last full block is never taken from the cache.
        cap_full = num_full if remainder else max(0, num_full - 1)

        prev_hash: Optional[int] = None
        cached_tokens = 0
        prefix_intact = self.enable_prefix_caching
        for i in range(num_full):
            tokens = tuple(prompt[i * B : (i + 1) * B])
            h = self._block_hash(prev_hash, tokens)
            prev_hash = h
            self.prefix_cache_lookups += 1

            # Block id 0 is valid, so compare against None, not truthiness.
            hit_bid = self._cache.get(h) if (prefix_intact and i < cap_full) else None
            if hit_bid is not None:
                self.prefix_cache_hits += 1
                self._share(hit_bid)
                seq.block_table.append(hit_bid)
                cached_tokens += B
                continue

            # Miss: every later block must be computed too, so stop
            # consulting the cache for this sequence.
            prefix_intact = False
            # The block content is fully known (prompt tokens), so register
            # its hash right away; the next request with this prefix can hit.
            # NOTE(review): this publishes the block before its KV has been
            # written; confirm the scheduler cannot let another sequence
            # read it before this sequence's prefill has filled it.
            bid = self._take_free_block()
            self._register(bid, h)
            seq.block_table.append(bid)

        # Trailing partial block, if any.
        if remainder:
            seq.block_table.append(self._take_free_block())

        seq.num_computed_tokens = cached_tokens
        seq.num_cached_prefix_tokens = cached_tokens

    def append_slot(self, seq: Sequence) -> Optional[int]:
        """Ensure `seq` has a slot for one more token (decode path).

        Returns the block_id that was newly allocated, or None if existing
        capacity already covered the new token.

        Raises:
            RuntimeError: If a new block is needed and none is free.
        """
        new_position = seq.total_len  # 0-indexed slot we are about to write
        needed_blocks = self.num_blocks_needed_for(new_position + 1)
        if needed_blocks <= len(seq.block_table):
            return None
        if self.num_free_blocks == 0:
            raise RuntimeError("out of blocks")
        bid = self._take_free_block()
        seq.block_table.append(bid)
        return bid

    def ensure_blocks_for_chunk(self, seq: Sequence, chunk_tokens: int) -> int:
        """Prefill path: make sure `seq.block_table` covers
        `seq.num_computed_tokens + chunk_tokens` tokens.

        Returns the number of newly-allocated blocks.

        Raises:
            RuntimeError: If the pool runs out of free blocks. Blocks taken
                before the failure remain in ``seq.block_table``.
        """
        target = seq.num_computed_tokens + chunk_tokens
        needed = self.num_blocks_needed_for(target)
        new_alloc = 0
        while len(seq.block_table) < needed:
            seq.block_table.append(self._take_free_block())
            new_alloc += 1
        return new_alloc

    def free(self, seq: Sequence) -> None:
        """Release every block held by `seq` and clear its block table.

        Blocks are released tail-first. The cached free list is consumed
        from the left, so the end of a cached prefix is evicted before its
        beginning, and the remaining head of the chain stays reachable
        through lookups that start from position 0.
        """
        for bid in reversed(seq.block_table):
            self._release(bid)
        seq.block_table.clear()

    # ---- post-step bookkeeping -----------------------------------------
    def register_filled_blocks(self, seq: Sequence, prev_computed: int) -> None:
        """After a forward pass, hash & register any blocks that just became
        full so future requests can prefix-cache them.

        Args:
            seq: Sequence whose first ``num_computed_tokens // block_size``
                blocks are full. ``seq.get_token(p)`` is assumed to return
                the token id at absolute position ``p`` (TODO confirm
                against ``Sequence``).
            prev_computed: Currently unused; the hash chain is always
                rebuilt from the first block.
        """
        if not self.enable_prefix_caching:
            return
        B = self.block_size
        # Re-chain hashes from the start so prev_hash is always correct.
        prev_hash: Optional[int] = None
        for i in range(seq.num_computed_tokens // B):
            bid = seq.block_table[i]
            blk = self.blocks[bid]
            if blk.hash_key is not None:
                prev_hash = blk.hash_key
                continue
            # This block became full in this step (or earlier but unhashed,
            # e.g. a duplicate whose hash was already owned by another block).
            tokens = tuple(seq.get_token(i * B + j) for j in range(B))
            h = self._block_hash(prev_hash, tokens)
            self._register(bid, h)
            prev_hash = h
|