File size: 2,808 Bytes
"""The actual KV tensor pool that the BlockManager indexes into.
We store one ``[num_blocks, block_size, num_kv_heads, head_dim]`` tensor per
layer for K and V. The block_manager owns the *allocation* of block ids; this
class owns the *bytes*. Reads and writes happen by (block_id, offset).
"""
from __future__ import annotations
import torch
class PagedKVCache:
    """Per-layer paged storage for attention keys and values.

    One K tensor and one V tensor are kept per layer, each shaped
    ``[num_blocks, block_size, num_kv_heads, head_dim]``. A token is addressed
    either by ``(block_id, offset)`` or by the equivalent flat slot id
    ``block_id * block_size + offset``. This class only stores the data; it
    does not track which blocks are in use.
    """

    def __init__(
        self,
        num_layers: int,
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        """Allocate zero-filled K and V pools for every layer.

        Args:
            num_layers: Number of layers; one K and one V tensor each.
            num_blocks: Number of blocks in each pool.
            block_size: Number of token slots per block.
            num_kv_heads: Number of key/value heads stored per token.
            head_dim: Size of each head vector.
            dtype: Element type of the pools.
            device: Device the pools are allocated on.
        """
        self.num_layers = num_layers
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device

        shape = (num_blocks, block_size, num_kv_heads, head_dim)
        self.k_cache = [
            torch.zeros(shape, dtype=dtype, device=device)
            for _ in range(num_layers)
        ]
        self.v_cache = [
            torch.zeros(shape, dtype=dtype, device=device)
            for _ in range(num_layers)
        ]

    def write(
        self,
        layer_id: int,
        k: torch.Tensor,  # [T, num_kv_heads, head_dim]
        v: torch.Tensor,  # [T, num_kv_heads, head_dim]
        slot_mapping: torch.Tensor,  # [T] int64, slot_id = block_id*block_size + offset
    ) -> None:
        """Scatter ``T`` new key/value rows into one layer's pools.

        Args:
            layer_id: Index of the layer whose pools are written.
            k: Keys, ``[T, num_kv_heads, head_dim]``.
            v: Values, ``[T, num_kv_heads, head_dim]``.
            slot_mapping: ``[T]`` flat slot ids, one per row of ``k``/``v``,
                where ``slot_id = block_id * block_size + offset``.

        The inputs are converted to the cache's dtype and moved to the
        cache's device before the write, so callers may pass tensors that
        live elsewhere.

        NOTE(review): slot ids are not range-checked here. With floor
        division/modulo, a negative slot id (e.g. a ``-1`` padding value)
        would index from the end of the pool rather than being skipped —
        confirm callers never pass padding slots.
        """
        k_pool = self.k_cache[layer_id]
        v_pool = self.v_cache[layer_id]

        # Index tensors must live on the same device as the pool they index.
        slots = slot_mapping.to(k_pool.device)
        block_ids = (slots // self.block_size).long()
        offsets = (slots % self.block_size).long()

        k_pool[block_ids, offsets] = k.to(device=k_pool.device, dtype=self.dtype)
        v_pool[block_ids, offsets] = v.to(device=v_pool.device, dtype=self.dtype)

    def gather(
        self,
        layer_id: int,
        block_table: list[int],
        num_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return contiguous K and V for one sequence by walking its block table.

        Args:
            layer_id: Index of the layer to read from.
            block_table: Block ids of the sequence, in token order.
            num_tokens: Number of tokens to return.

        Returns:
            A ``(k, v)`` pair, each ``[num_tokens, num_kv_heads, head_dim]``.

        Raises:
            ValueError: If ``num_tokens`` is negative, or ``block_table`` has
                fewer blocks than ``num_tokens`` tokens occupy. Without this
                check the result would silently contain fewer rows than
                requested.
        """
        if num_tokens < 0:
            raise ValueError(f"num_tokens must be non-negative, got {num_tokens}")

        if num_tokens == 0:
            empty = torch.zeros(
                0, self.num_kv_heads, self.head_dim,
                dtype=self.dtype, device=self.device,
            )
            return empty, empty.clone()

        # Ceil division: a partially filled tail block still has to be read.
        num_needed = -(-num_tokens // self.block_size)
        if len(block_table) < num_needed:
            raise ValueError(
                f"block_table has {len(block_table)} block(s) but "
                f"{num_tokens} token(s) need {num_needed}"
            )

        idx_tensor = torch.as_tensor(
            block_table[:num_needed], dtype=torch.long, device=self.device
        )

        # [P, block_size, H, D]
        k_blocks = self.k_cache[layer_id].index_select(0, idx_tensor)
        v_blocks = self.v_cache[layer_id].index_select(0, idx_tensor)

        # Merge the (block, offset) dims into one token axis, then drop the
        # unused slots at the end of the tail block.
        k_flat = k_blocks.reshape(-1, self.num_kv_heads, self.head_dim)
        v_flat = v_blocks.reshape(-1, self.num_kv_heads, self.head_dim)
        return k_flat[:num_tokens], v_flat[:num_tokens]
|