| """ |
| ENGRAM Protocol — Block Pool Tests |
| Tests for 256-token block segmentation/assembly/extend. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import pytest |
| import torch |
|
|
| from kvcos.core.block_pool import BlockPool, KVBlock |
| from kvcos.core.types import BLOCK_SIZE_TOKENS |
|
|
|
|
| def _kv(n_layers: int, n_heads: int, ctx: int, dim: int) -> tuple[torch.Tensor, torch.Tensor]: |
| k = torch.randn(n_layers, n_heads, ctx, dim, dtype=torch.float16) |
| return k, k.clone() |
|
|
|
|
| class TestSegment: |
| """Segment full KV cache into 256-token blocks.""" |
|
|
| def test_exact_blocks(self) -> None: |
| keys, vals = _kv(32, 8, 512, 128) |
| pool = BlockPool(agent_id="a", model_id="m") |
| blocks = pool.segment(keys, vals) |
| assert len(blocks) == 2 |
| assert all(b.is_full for b in blocks) |
|
|
| def test_partial_last_block(self) -> None: |
| keys, vals = _kv(32, 8, 300, 128) |
| pool = BlockPool(agent_id="a", model_id="m") |
| blocks = pool.segment(keys, vals) |
| assert len(blocks) == 2 |
| assert blocks[0].is_full |
| assert not blocks[1].is_full |
| assert blocks[1].block_len == 44 |
|
|
| def test_total_tokens(self) -> None: |
| keys, vals = _kv(32, 8, 700, 128) |
| pool = BlockPool(agent_id="a", model_id="m") |
| pool.segment(keys, vals) |
| assert pool.total_tokens == 700 |
|
|
|
|
| class TestAssemble: |
| """Assemble blocks back into full KV cache.""" |
|
|
| def test_round_trip(self) -> None: |
| keys, vals = _kv(4, 2, 512, 64) |
| pool = BlockPool(agent_id="a", model_id="m") |
| pool.segment(keys, vals) |
| k_out, v_out = pool.assemble() |
| assert torch.equal(k_out, keys) |
|
|
| def test_subset_assembly(self) -> None: |
| keys, vals = _kv(4, 2, 768, 64) |
| pool = BlockPool(agent_id="a", model_id="m") |
| pool.segment(keys, vals) |
| k_out, _ = pool.assemble(block_indices=[0, 2]) |
| assert k_out.shape[2] == BLOCK_SIZE_TOKENS * 2 |
|
|
| def test_empty_raises(self) -> None: |
| pool = BlockPool(agent_id="a", model_id="m") |
| with pytest.raises(ValueError, match="No blocks"): |
| pool.assemble() |
|
|
|
|
| class TestExtend: |
| """Extend pool with new tokens.""" |
|
|
| def test_fills_partial_block(self) -> None: |
| keys, vals = _kv(4, 2, 200, 64) |
| pool = BlockPool(agent_id="a", model_id="m") |
| pool.segment(keys, vals) |
| assert not pool.blocks[-1].is_full |
|
|
| new_k, new_v = _kv(4, 2, 56, 64) |
| pool.extend(new_k, new_v) |
| assert pool.blocks[-1].is_full |
| assert pool.total_tokens == 256 |
|
|
| def test_extend_creates_new_blocks(self) -> None: |
| keys, vals = _kv(4, 2, 256, 64) |
| pool = BlockPool(agent_id="a", model_id="m") |
| pool.segment(keys, vals) |
| assert pool.n_blocks == 1 |
|
|
| new_k, new_v = _kv(4, 2, 300, 64) |
| pool.extend(new_k, new_v) |
| assert pool.n_blocks == 3 |
| assert pool.total_tokens == 556 |
|
|