# Source: engram/tests/test_block_pool.py (commit 2ece486, "test: upload 220 tests", uploaded by eigengram)
"""
ENGRAM Protocol — Block Pool Tests
Tests for 256-token block segmentation/assembly/extend.
"""
from __future__ import annotations
import pytest
import torch
from kvcos.core.block_pool import BlockPool, KVBlock
from kvcos.core.types import BLOCK_SIZE_TOKENS
def _kv(n_layers: int, n_heads: int, ctx: int, dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a random float16 key/value tensor pair for the block-pool tests.

    Args:
        n_layers: Size of the first tensor axis.
        n_heads: Size of the second tensor axis.
        ctx: Size of the third axis (the token/context axis the tests slice on).
        dim: Size of the last tensor axis.

    Returns:
        A ``(keys, values)`` tuple; both tensors have shape
        ``(n_layers, n_heads, ctx, dim)`` and dtype ``torch.float16``.
    """
    shape = (n_layers, n_heads, ctx, dim)
    keys = torch.randn(shape, dtype=torch.float16)
    # Draw values independently instead of cloning the keys: with identical
    # tensors, a pool that swapped or mixed up keys and values would still
    # pass every equality check built on this helper.
    values = torch.randn(shape, dtype=torch.float16)
    return keys, values
class TestSegment:
    """Segment a full KV cache into blocks of ``BLOCK_SIZE_TOKENS`` tokens."""

    def test_exact_blocks(self) -> None:
        """A context of exactly two block lengths yields two full blocks."""
        keys, vals = _kv(32, 8, 2 * BLOCK_SIZE_TOKENS, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        blocks = pool.segment(keys, vals)
        assert len(blocks) == 2
        assert all(b.is_full for b in blocks)

    def test_partial_last_block(self) -> None:
        """Tokens left over after the last full block form a partial block."""
        # Any tail shorter than one block works; 44 reproduces the original
        # 300-token case when BLOCK_SIZE_TOKENS is 256.
        tail = 44
        keys, vals = _kv(32, 8, BLOCK_SIZE_TOKENS + tail, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        blocks = pool.segment(keys, vals)
        assert len(blocks) == 2
        assert blocks[0].is_full
        assert not blocks[1].is_full
        assert blocks[1].block_len == tail

    def test_total_tokens(self) -> None:
        """The pool reports the full token count of the segmented cache."""
        n_tokens = 700
        keys, vals = _kv(32, 8, n_tokens, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert pool.total_tokens == n_tokens
class TestAssemble:
    """Assemble blocks back into a full KV cache."""

    def test_round_trip(self) -> None:
        """Segmenting then assembling returns the original keys and values."""
        keys, vals = _kv(4, 2, 2 * BLOCK_SIZE_TOKENS, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        k_out, v_out = pool.assemble()
        assert torch.equal(k_out, keys)
        # The value half of the round trip was previously unpacked but never
        # verified; check it as well.
        assert torch.equal(v_out, vals)

    def test_subset_assembly(self) -> None:
        """Assembling two of three full blocks yields two blocks of tokens."""
        keys, vals = _kv(4, 2, 3 * BLOCK_SIZE_TOKENS, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        k_out, v_out = pool.assemble(block_indices=[0, 2])
        assert k_out.shape[2] == BLOCK_SIZE_TOKENS * 2
        # NOTE(review): assumes values are assembled with the same token
        # length as keys - confirm against BlockPool.assemble.
        assert v_out.shape[2] == BLOCK_SIZE_TOKENS * 2

    def test_empty_raises(self) -> None:
        """Assembling a pool that holds no blocks raises ``ValueError``."""
        pool = BlockPool(agent_id="a", model_id="m")
        with pytest.raises(ValueError, match="No blocks"):
            pool.assemble()
class TestExtend:
    """Extend a pool with new tokens."""

    def test_fills_partial_block(self) -> None:
        """Extending by exactly the missing tokens completes the last block."""
        # 56 missing tokens reproduces the original 200 + 56 case when
        # BLOCK_SIZE_TOKENS is 256; any value below one block length works.
        missing = 56
        keys, vals = _kv(4, 2, BLOCK_SIZE_TOKENS - missing, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert not pool.blocks[-1].is_full

        new_k, new_v = _kv(4, 2, missing, 64)
        pool.extend(new_k, new_v)
        assert pool.blocks[-1].is_full
        assert pool.total_tokens == BLOCK_SIZE_TOKENS

    def test_extend_creates_new_blocks(self) -> None:
        """Extending a full pool past one more block length adds two blocks."""
        keys, vals = _kv(4, 2, BLOCK_SIZE_TOKENS, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert pool.n_blocks == 1

        # One further full block plus a 44-token tail (300 new tokens when
        # BLOCK_SIZE_TOKENS is 256).
        tail = 44
        new_k, new_v = _kv(4, 2, BLOCK_SIZE_TOKENS + tail, 64)
        pool.extend(new_k, new_v)
        assert pool.n_blocks == 3
        assert pool.total_tokens == 2 * BLOCK_SIZE_TOKENS + tail