"""
ENGRAM Protocol — 256-Token Block Pool Manager
Segments a full KV cache into fixed-size blocks (256 tokens each) that can be:
- Stored independently (one .eng file per block — D7)
- Retrieved individually via EGR (fine-grained cache hits)
- Composed (assemble a context from multiple blocks)
- Evicted independently (LRU per block, not per session)
Design from arXiv:2603.04428 (Persistent Q4 KV Cache, agent-memory paper).
"""
from __future__ import annotations
from dataclasses import dataclass, field
import torch
from kvcos.core.types import BLOCK_SIZE_TOKENS
@dataclass
class KVBlock:
    """One fixed-size (256-token) slice of a KV cache.

    Holds the key/value tensors covering the contiguous token range
    [token_start, token_end) of the original context. A block is the unit
    of storage, retrieval, and eviction in the ENGRAM design.
    """

    block_index: int  # position of this block within its pool
    token_start: int  # first token covered (inclusive)
    token_end: int  # one past the last token covered (exclusive)
    keys: torch.Tensor  # [n_layers, n_kv_heads, block_len, head_dim]
    values: torch.Tensor  # [n_layers, n_kv_heads, block_len, head_dim]

    @property
    def block_len(self) -> int:
        """Number of tokens stored in this block."""
        return self.token_end - self.token_start

    @property
    def is_full(self) -> bool:
        """Whether the block holds exactly BLOCK_SIZE_TOKENS tokens."""
        return self.block_len == BLOCK_SIZE_TOKENS

    @property
    def n_layers(self) -> int:
        """Layer count, read from axis 0 of the keys tensor."""
        return self.keys.size(0)

    @property
    def n_kv_heads(self) -> int:
        """KV-head count, read from axis 1 of the keys tensor."""
        return self.keys.size(1)

    @property
    def head_dim(self) -> int:
        """Per-head embedding width, read from axis 3 of the keys tensor."""
        return self.keys.size(3)
@dataclass
class BlockPool:
    """Manages a collection of KV blocks for an agent session.

    Blocks are kept in token order; each covers a contiguous
    [token_start, token_end) range and holds at most BLOCK_SIZE_TOKENS tokens.
    Only the final block may be partially filled.
    """

    agent_id: str  # owning agent session
    model_id: str  # model that produced the cached KV tensors
    blocks: list[KVBlock] = field(default_factory=list)

    @property
    def total_tokens(self) -> int:
        """Total number of tokens held across all blocks."""
        return sum(b.block_len for b in self.blocks)

    @property
    def n_blocks(self) -> int:
        """Number of blocks currently in the pool."""
        return len(self.blocks)

    @staticmethod
    def _slice_blocks(
        keys: torch.Tensor,
        values: torch.Tensor,
        *,
        index_base: int = 0,
        token_base: int = 0,
    ) -> list[KVBlock]:
        """Cut KV tensors along the ctx_len axis into BLOCK_SIZE_TOKENS blocks.

        The final block may be shorter than BLOCK_SIZE_TOKENS. `index_base`
        and `token_base` offset the produced block indices / token ranges so
        this helper serves both segment() (bases 0) and extend().
        """
        ctx_len = keys.shape[2]
        out: list[KVBlock] = []
        for start in range(0, ctx_len, BLOCK_SIZE_TOKENS):
            end = min(start + BLOCK_SIZE_TOKENS, ctx_len)
            out.append(KVBlock(
                block_index=index_base + len(out),
                token_start=token_base + start,
                token_end=token_base + end,
                # .contiguous() so each block owns its storage and can be
                # serialized / evicted independently (one .eng file per block).
                keys=keys[:, :, start:end, :].contiguous(),
                values=values[:, :, start:end, :].contiguous(),
            ))
        return out

    def segment(
        self, keys: torch.Tensor, values: torch.Tensor,
    ) -> list[KVBlock]:
        """Segment a full KV cache into 256-token blocks, replacing self.blocks.

        Args:
            keys: [n_layers, n_kv_heads, ctx_len, head_dim]
            values: [n_layers, n_kv_heads, ctx_len, head_dim]

        Returns:
            The new block list (also stored on the pool).

        Raises:
            ValueError: if keys and values shapes differ.
        """
        if keys.shape != values.shape:
            raise ValueError(f"Shape mismatch: keys {keys.shape} vs values {values.shape}")
        self.blocks = self._slice_blocks(keys, values)
        return self.blocks

    def assemble(
        self, block_indices: list[int] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Assemble a KV cache by concatenating blocks along the ctx_len dim.

        Args:
            block_indices: blocks to include, in the given order (duplicates
                allowed); None selects every block in pool order.

        Returns:
            (keys, values) tensors concatenated along dim 2.

        Raises:
            ValueError: if the pool is empty or no blocks are selected.
        """
        if not self.blocks:
            raise ValueError("No blocks to assemble")
        selected = self.blocks if block_indices is None else [self.blocks[i] for i in block_indices]
        if not selected:
            raise ValueError("No blocks selected for assembly")
        keys = torch.cat([b.keys for b in selected], dim=2)
        values = torch.cat([b.values for b in selected], dim=2)
        return keys, values

    def append_block(self, block: KVBlock) -> None:
        """Append a block, reassigning its index to the end of the pool.

        NOTE(review): token_start/token_end are left untouched, so a block
        appended from another pool may not be token-contiguous with the tail —
        presumably intentional for composing contexts; confirm with callers.
        """
        block.block_index = len(self.blocks)
        self.blocks.append(block)

    def get_block(self, index: int) -> KVBlock:
        """Return the block at `index`.

        Raises:
            IndexError: if index is negative or >= n_blocks.
        """
        if index < 0 or index >= len(self.blocks):
            raise IndexError(f"Block index {index} out of range [0, {len(self.blocks)})")
        return self.blocks[index]

    def extend(
        self, new_keys: torch.Tensor, new_values: torch.Tensor,
    ) -> list[KVBlock]:
        """Extend the pool with additional tokens, filling the last block first.

        Args:
            new_keys: [n_layers, n_kv_heads, new_ctx_len, head_dim]
            new_values: same shape as new_keys.

        Returns:
            The blocks that were created or modified; empty when no tokens
            were added.

        Raises:
            ValueError: if new_keys and new_values shapes differ.
        """
        # Validate up front, consistent with segment().
        if new_keys.shape != new_values.shape:
            raise ValueError(f"Shape mismatch: keys {new_keys.shape} vs values {new_values.shape}")
        new_ctx_len = new_keys.shape[2]
        modified_blocks: list[KVBlock] = []
        offset = 0
        # Top up a partially-filled tail block before opening new ones.
        # Guard on new_ctx_len so an empty extend doesn't rebuild the tail
        # and falsely report it as modified.
        if new_ctx_len > 0 and self.blocks and not self.blocks[-1].is_full:
            last = self.blocks[-1]
            fill = min(BLOCK_SIZE_TOKENS - last.block_len, new_ctx_len)
            merged_k = torch.cat([last.keys, new_keys[:, :, :fill, :]], dim=2).contiguous()
            merged_v = torch.cat([last.values, new_values[:, :, :fill, :]], dim=2).contiguous()
            self.blocks[-1] = KVBlock(
                block_index=last.block_index,
                token_start=last.token_start,
                token_end=last.token_start + merged_k.shape[2],
                keys=merged_k,
                values=merged_v,
            )
            modified_blocks.append(self.blocks[-1])
            offset = fill
        # Any remaining tokens open fresh blocks after the current tail.
        if new_ctx_len - offset > 0:
            token_base = self.blocks[-1].token_end if self.blocks else 0
            new_blocks = self._slice_blocks(
                new_keys[:, :, offset:, :], new_values[:, :, offset:, :],
                index_base=len(self.blocks), token_base=token_base,
            )
            self.blocks.extend(new_blocks)
            modified_blocks.extend(new_blocks)
        return modified_blocks

    def clear(self) -> None:
        """Drop all blocks from the pool."""
        self.blocks.clear()