File size: 10,748 Bytes
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""Paged KV-cache block manager with hash-based automatic prefix caching.

Concepts (matching vLLM / SGLang terminology):

  Physical block:  a fixed-size slot in the KV-cache pool that holds the K and V
                   tensors for ``block_size`` consecutive tokens of one sequence.

  Block table:     per-sequence list of physical block ids that holds the
                   sequence's KV in logical order.  Position ``p`` of the
                   sequence lives in physical block ``block_table[p // B]`` at
                   slot ``p % B``.

  Prefix cache:    a content-addressed lookup from
                       hash(prev_block_hash, tuple_of_token_ids_in_block)
                   to a physical block id.  When two sequences share a prefix
                   that aligns to a block boundary, the second sequence can
                   point its block_table at the cached blocks instead of
                   recomputing KV, and the scheduler can skip those tokens.

The "chained" hash means two prefixes match iff they are identical from
position 0 — exactly the property we need for prefix sharing.

This manager is allocation-only: it does NOT store the KV tensors.  The
ModelRunner owns the actual ``[num_blocks, ...]`` tensors and consults the
block tables here to know where to write/read KV.
"""
from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from typing import Optional

from .request import Sequence


@dataclass
class Block:
    """Bookkeeping record for one physical KV-cache block.

    Holds allocation metadata only; the KV tensors themselves are stored
    elsewhere and are addressed by ``block_id``.
    """

    # Index of this block in the physical pool.
    block_id: int
    # Number of block tables currently pointing at this block; 0 means free.
    ref_count: int = 0
    # Chained content hash under which the block is registered in the prefix
    # cache, or None while the block is not registered.
    hash_key: Optional[int] = None  # set when the block is full and registered


class BlockManager:
    """Allocate, share and recycle physical KV-cache block ids.

    The manager only does bookkeeping (reference counts, free lists and the
    ``hash -> block_id`` prefix cache); it never touches KV tensors.

    Free blocks live in two FIFO queues:

    * ``_free_uncached`` holds blocks with no registered hash and is always
      drained first.
    * ``_free_cached`` holds blocks whose ``ref_count`` is 0 but which still
      back a prefix-cache entry.  They are only repurposed (and their cache
      entry evicted) once ``_free_uncached`` is empty, oldest release first.
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        enable_prefix_caching: bool = True,
    ) -> None:
        """Create a pool of ``num_blocks`` free blocks.

        Args:
            num_blocks: Number of physical blocks in the pool (>= 0).
            block_size: Number of token slots per block (> 0).
            enable_prefix_caching: When False, no hashes are registered and
                no block is ever shared between sequences.

        Raises:
            ValueError: If ``num_blocks`` is negative or ``block_size`` is
                not positive.
        """
        # Fail fast: a non-positive block_size would otherwise surface later
        # as a ZeroDivisionError (or nonsense block counts) deep in a step.
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be >= 0, got {num_blocks}")
        if block_size <= 0:
            raise ValueError(f"block_size must be > 0, got {block_size}")

        self.num_blocks = num_blocks
        self.block_size = block_size
        self.enable_prefix_caching = enable_prefix_caching

        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        # Two-tier free list: ephemeral (no hash) reused first, then cached
        # (preserved as long as we have ephemeral capacity).
        self._free_uncached: deque[int] = deque(range(num_blocks))
        self._free_cached: deque[int] = deque()
        self._cache: dict[int, int] = {}  # hash → block_id

        # Stats (visible via events).
        self.prefix_cache_hits = 0
        self.prefix_cache_lookups = 0

    # ---- introspection --------------------------------------------------

    @property
    def num_free_blocks(self) -> int:
        """Blocks available for allocation (both free tiers combined)."""
        return len(self._free_uncached) + len(self._free_cached)

    @property
    def num_used_blocks(self) -> int:
        """Blocks currently referenced by at least one block table."""
        return self.num_blocks - self.num_free_blocks

    def snapshot(self) -> dict:
        """Cheap dict for the event stream / UI."""
        return {
            "num_blocks": self.num_blocks,
            "block_size": self.block_size,
            "num_free_blocks": self.num_free_blocks,
            "num_cached_entries": len(self._cache),
            "prefix_cache_hits": self.prefix_cache_hits,
            "prefix_cache_lookups": self.prefix_cache_lookups,
            "ref_counts": [b.ref_count for b in self.blocks],
            "hashed": [b.hash_key is not None for b in self.blocks],
        }

    # ---- low-level pool ops --------------------------------------------

    def _block_hash(self, prev_hash: Optional[int], token_ids: tuple[int, ...]) -> int:
        """Return the chained hash of a block given its predecessor's hash.

        ``prev_hash`` is None for the first block of a sequence.
        """
        # Python's hash() is randomized per process but that's fine: the cache
        # only lives for the engine's lifetime.
        # NOTE(review): cache entries are matched on this value alone, so a
        # hash collision between two different token chains would alias
        # their blocks.
        return hash((prev_hash, token_ids))

    def _take_free_block(self) -> int:
        """Pop a free block, set its ``ref_count`` to 1 and return its id.

        Un-hashed blocks are preferred; a cached block is only repurposed
        when none remain, in which case its prefix-cache entry is evicted.

        Raises:
            RuntimeError: If both free lists are empty.
        """
        if self._free_uncached:
            bid = self._free_uncached.popleft()
            blk = self.blocks[bid]
        elif self._free_cached:
            bid = self._free_cached.popleft()
            blk = self.blocks[bid]
            # Evict its cache entry — we're about to repurpose it.
            if blk.hash_key is not None:
                self._cache.pop(blk.hash_key, None)
                blk.hash_key = None
        else:
            raise RuntimeError("BlockManager out of free blocks")
        blk.ref_count = 1
        return bid

    def _share(self, block_id: int) -> None:
        """Add one reference to ``block_id``, reviving it if it was free."""
        blk = self.blocks[block_id]
        if blk.ref_count == 0:
            # Was sitting in the cached free list; pull it out.  Best-effort:
            # a block that is not in the list is tolerated.
            try:
                self._free_cached.remove(block_id)
            except ValueError:
                pass
        blk.ref_count += 1

    def _release(self, block_id: int) -> None:
        """Drop one reference; queue the block for reuse when none remain."""
        blk = self.blocks[block_id]
        blk.ref_count -= 1
        assert blk.ref_count >= 0, f"block {block_id} refcount went negative"
        if blk.ref_count == 0:
            if blk.hash_key is not None and self.enable_prefix_caching:
                # Still backs a cache entry: keep it hittable until evicted.
                self._free_cached.append(block_id)
            else:
                self._free_uncached.append(block_id)

    def _register(self, block_id: int, hash_key: int) -> None:
        """Publish ``block_id`` in the prefix cache under ``hash_key``.

        No-op when prefix caching is disabled or the hash is already mapped.
        """
        if not self.enable_prefix_caching:
            return
        if hash_key in self._cache:
            # Two sequences independently produced the same content for
            # different physical blocks. Keep the older one; this one becomes
            # ephemeral so it gets reclaimed first.
            return
        self.blocks[block_id].hash_key = hash_key
        self._cache[hash_key] = block_id

    # ---- per-sequence allocation ---------------------------------------

    def num_blocks_needed_for(self, num_tokens: int) -> int:
        """Number of blocks required to hold ``num_tokens`` tokens (ceil)."""
        return (num_tokens + self.block_size - 1) // self.block_size

    def can_allocate_initial(self, seq: Sequence) -> tuple[bool, int]:
        """Worst-case allocation check for the prompt of `seq`, ignoring prefix
        cache hits.  Returns (ok, num_new_blocks_needed)."""
        need = self.num_blocks_needed_for(seq.prompt_len)
        return self.num_free_blocks >= need, need

    def admit(self, seq: Sequence) -> None:
        """Set up `seq` in the cache.

        Walks the prompt block-by-block.  Leading full blocks that are found
        in the prefix cache are shared; from the first miss onwards every
        block is allocated fresh, because the KV of a later block is only
        reusable when everything before it is reused too.  Fresh full blocks
        are registered immediately since their tokens are already known.

        The trailing partial block (if any) is allocated fresh and left
        un-hashed; it is hashed later by `register_filled_blocks` once it
        fills up.

        On return ``seq.block_table`` covers the whole prompt, and both
        ``seq.num_computed_tokens`` and ``seq.num_cached_prefix_tokens`` are
        set to the number of prompt tokens served from the cache.

        Raises:
            RuntimeError: If the pool runs out of free blocks.  Blocks
                acquired so far remain in ``seq.block_table``.
        """
        assert not seq.block_table, "admit called on an already-admitted sequence"
        B = self.block_size
        prompt = seq.prompt_token_ids
        num_full, remainder = divmod(seq.prompt_len, B)

        # IMPORTANT: never let prefix cache cover the entire prompt — we need
        # at least one token to forward through the model to get logits for
        # the first sampled token.  If the full prompt block-aligns AND every
        # block is cached, drop the last cached block.
        cap_full = num_full if remainder else max(0, num_full - 1)

        prev_hash: Optional[int] = None
        cached_tokens = 0
        # Becomes False at the first miss.  Without this, a block whose
        # predecessor had been evicted could still hit, and cached_tokens
        # would then claim a leading prefix that was never computed.
        prefix_intact = self.enable_prefix_caching

        for i in range(num_full):
            tokens = tuple(prompt[i * B : (i + 1) * B])
            h = self._block_hash(prev_hash, tokens)
            self.prefix_cache_lookups += 1

            hit_bid: Optional[int] = None
            if prefix_intact and i < cap_full:
                hit_bid = self._cache.get(h)

            if hit_bid is not None:
                # Cache hit.
                self.prefix_cache_hits += 1
                bid = hit_bid
                self._share(bid)
                cached_tokens += B
            else:
                # Miss: allocate, and since the block content is fully known
                # (prompt tokens), register its hash right away so the next
                # request with this prefix can hit.
                # NOTE(review): the hash is published before this block's KV
                # has been computed; presumably the scheduler guarantees it
                # is filled before any sharer reads it — confirm.
                prefix_intact = False
                bid = self._take_free_block()
                self._register(bid, h)

            seq.block_table.append(bid)
            prev_hash = h

        # Trailing partial block, if any.
        if remainder:
            seq.block_table.append(self._take_free_block())

        seq.num_computed_tokens = cached_tokens
        seq.num_cached_prefix_tokens = cached_tokens

    def append_slot(self, seq: Sequence) -> Optional[int]:
        """Ensure `seq` has a slot for one more token (decode path).

        Returns the block_id that was newly allocated, or None if existing
        capacity already covered the new token.  Raises if no block available.
        """
        new_position = seq.total_len  # 0-indexed slot we are about to write
        needed_blocks = self.num_blocks_needed_for(new_position + 1)
        if needed_blocks <= len(seq.block_table):
            return None
        if self.num_free_blocks == 0:
            raise RuntimeError("out of blocks")
        bid = self._take_free_block()
        seq.block_table.append(bid)
        return bid

    def ensure_blocks_for_chunk(self, seq: Sequence, chunk_tokens: int) -> int:
        """Prefill path: make sure `seq.block_table` covers
        `seq.num_computed_tokens + chunk_tokens` tokens.

        Returns number of newly-allocated blocks.  Raises RuntimeError if the
        pool is exhausted; blocks appended before the failure stay in
        ``seq.block_table``.
        """
        target = seq.num_computed_tokens + chunk_tokens
        needed = self.num_blocks_needed_for(target)
        new_alloc = 0
        while len(seq.block_table) < needed:
            seq.block_table.append(self._take_free_block())
            new_alloc += 1
        return new_alloc

    def free(self, seq: Sequence) -> None:
        """Release every block held by `seq` and clear its block table."""
        # Release tail-first so that, in the FIFO cached free list, the tail
        # of a hash chain is evicted before its head.  A tail block is
        # useless for prefix matching once its head is gone, whereas a head
        # block stays useful on its own.
        for bid in reversed(seq.block_table):
            self._release(bid)
        seq.block_table.clear()

    # ---- post-step bookkeeping -----------------------------------------

    def register_filled_blocks(self, seq: Sequence, prev_computed: int) -> None:
        """After a forward pass, hash & register any blocks that just became
        full so future requests can prefix-cache them.

        Args:
            seq: Sequence whose first ``seq.num_computed_tokens`` tokens have
                been computed.
            prev_computed: Currently unused; the hash chain is rebuilt from
                block 0 on every call.
        """
        if not self.enable_prefix_caching:
            return
        B = self.block_size
        # Re-chain hashes from the start so we always have prev_hash correct.
        prev_hash: Optional[int] = None
        # The range bound guarantees every visited block is completely full.
        for i in range(seq.num_computed_tokens // B):
            bid = seq.block_table[i]
            blk = self.blocks[bid]
            if blk.hash_key is not None:
                prev_hash = blk.hash_key
                continue
            # This block became full in this step (or earlier but unhashed).
            tokens = tuple(seq.get_token(i * B + j) for j in range(B))
            h = self._block_hash(prev_hash, tokens)
            self._register(bid, h)
            prev_hash = h