File size: 2,907 Bytes
2ece486 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 | """
ENGRAM Protocol — Block Pool Tests
Tests for 256-token block segmentation/assembly/extend.
"""
from __future__ import annotations
import pytest
import torch
from kvcos.core.block_pool import BlockPool, KVBlock
from kvcos.core.types import BLOCK_SIZE_TOKENS
def _kv(
    n_layers: int,
    n_heads: int,
    ctx: int,
    dim: int,
    *,
    dtype: torch.dtype = torch.float16,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return random ``(keys, values)`` tensors of shape ``(n_layers, n_heads, ctx, dim)``.

    Keys and values are drawn independently. If values were a copy of keys, a
    pool that swapped, dropped, or duplicated K/V would still pass every
    equality check in this module.

    Args:
        n_layers: Size of dim 0.
        n_heads: Size of dim 1.
        ctx: Number of tokens (dim 2, the axis the pool segments along).
        dim: Size of dim 3.
        dtype: Element type of both tensors (defaults to float16).
    """
    shape = (n_layers, n_heads, ctx, dim)
    keys = torch.randn(shape, dtype=dtype)
    values = torch.randn(shape, dtype=dtype)
    return keys, values
class TestSegment:
    """Segment a full KV cache into BLOCK_SIZE_TOKENS-token blocks."""

    def test_exact_blocks(self) -> None:
        """A context that is an exact multiple of the block size yields only full blocks."""
        keys, vals = _kv(32, 8, 2 * BLOCK_SIZE_TOKENS, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        blocks = pool.segment(keys, vals)
        assert len(blocks) == 2
        assert all(b.is_full for b in blocks)
        assert all(b.block_len == BLOCK_SIZE_TOKENS for b in blocks)

    def test_partial_last_block(self) -> None:
        """Leftover tokens beyond the last full block land in a shorter, non-full block."""
        tail = 44
        keys, vals = _kv(32, 8, BLOCK_SIZE_TOKENS + tail, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        blocks = pool.segment(keys, vals)
        assert len(blocks) == 2
        assert blocks[0].is_full
        assert not blocks[1].is_full
        assert blocks[1].block_len == tail

    def test_total_tokens(self) -> None:
        """The pool reports the full token count and a ceil(ctx / block) block count."""
        ctx = 700
        keys, vals = _kv(32, 8, ctx, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert pool.total_tokens == ctx
        # Ceiling division: a trailing partial block still counts as a block.
        assert pool.n_blocks == (ctx + BLOCK_SIZE_TOKENS - 1) // BLOCK_SIZE_TOKENS
class TestAssemble:
    """Assemble blocks back into a full KV cache."""

    def test_round_trip(self) -> None:
        """Segmenting then assembling every block reproduces both keys and values exactly."""
        keys, vals = _kv(4, 2, 2 * BLOCK_SIZE_TOKENS, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        k_out, v_out = pool.assemble()
        assert torch.equal(k_out, keys)
        # Values must round-trip too; checking keys alone would miss a K/V mix-up.
        assert torch.equal(v_out, vals)

    def test_subset_assembly(self) -> None:
        """Assembling selected blocks returns exactly those blocks' tokens, in index order."""
        block = BLOCK_SIZE_TOKENS
        keys, vals = _kv(4, 2, 3 * block, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        k_out, v_out = pool.assemble(block_indices=[0, 2])
        assert k_out.shape[2] == block * 2
        # Tokens live on dim 2; blocks 0 and 2 are the first and third slices.
        expected_k = torch.cat([keys[:, :, :block], keys[:, :, 2 * block : 3 * block]], dim=2)
        expected_v = torch.cat([vals[:, :, :block], vals[:, :, 2 * block : 3 * block]], dim=2)
        assert torch.equal(k_out, expected_k)
        assert torch.equal(v_out, expected_v)

    def test_empty_raises(self) -> None:
        """Assembling a pool that holds no blocks is an error."""
        pool = BlockPool(agent_id="a", model_id="m")
        with pytest.raises(ValueError, match="No blocks"):
            pool.assemble()
class TestExtend:
    """Extend a pool with new tokens."""

    def test_fills_partial_block(self) -> None:
        """Extending by exactly the free space completes the last block without adding one."""
        free = 56
        keys, vals = _kv(4, 2, BLOCK_SIZE_TOKENS - free, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert not pool.blocks[-1].is_full
        new_k, new_v = _kv(4, 2, free, 64)
        pool.extend(new_k, new_v)
        assert pool.n_blocks == 1
        assert pool.blocks[-1].is_full
        assert pool.total_tokens == BLOCK_SIZE_TOKENS

    def test_extend_creates_new_blocks(self) -> None:
        """Overflowing the last block appends new blocks and preserves all K/V data."""
        keys, vals = _kv(4, 2, BLOCK_SIZE_TOKENS, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert pool.n_blocks == 1
        tail = 44
        new_k, new_v = _kv(4, 2, BLOCK_SIZE_TOKENS + tail, 64)
        pool.extend(new_k, new_v)
        assert pool.n_blocks == 3
        assert pool.total_tokens == 2 * BLOCK_SIZE_TOKENS + tail
        assert pool.blocks[-1].block_len == tail
        # The pool should now hold old tokens followed by new ones along dim 2.
        k_out, v_out = pool.assemble()
        assert torch.equal(k_out, torch.cat([keys, new_k], dim=2))
        assert torch.equal(v_out, torch.cat([vals, new_v], dim=2))
|