File size: 2,907 Bytes
2ece486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
ENGRAM Protocol — Block Pool Tests
Tests for 256-token block segmentation/assembly/extend.
"""

from __future__ import annotations

import pytest
import torch

from kvcos.core.block_pool import BlockPool, KVBlock
from kvcos.core.types import BLOCK_SIZE_TOKENS


def _kv(n_layers: int, n_heads: int, ctx: int, dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    k = torch.randn(n_layers, n_heads, ctx, dim, dtype=torch.float16)
    return k, k.clone()


class TestSegment:
    """Segment full KV cache into 256-token blocks."""

    def test_exact_blocks(self) -> None:
        keys, vals = _kv(32, 8, 512, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        blocks = pool.segment(keys, vals)
        assert len(blocks) == 2
        assert all(b.is_full for b in blocks)

    def test_partial_last_block(self) -> None:
        keys, vals = _kv(32, 8, 300, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        blocks = pool.segment(keys, vals)
        assert len(blocks) == 2
        assert blocks[0].is_full
        assert not blocks[1].is_full
        assert blocks[1].block_len == 44

    def test_total_tokens(self) -> None:
        keys, vals = _kv(32, 8, 700, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert pool.total_tokens == 700


class TestAssemble:
    """Assemble blocks back into full KV cache."""

    def test_round_trip(self) -> None:
        keys, vals = _kv(4, 2, 512, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        k_out, v_out = pool.assemble()
        assert torch.equal(k_out, keys)

    def test_subset_assembly(self) -> None:
        keys, vals = _kv(4, 2, 768, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        k_out, _ = pool.assemble(block_indices=[0, 2])
        assert k_out.shape[2] == BLOCK_SIZE_TOKENS * 2

    def test_empty_raises(self) -> None:
        pool = BlockPool(agent_id="a", model_id="m")
        with pytest.raises(ValueError, match="No blocks"):
            pool.assemble()


class TestExtend:
    """Extend pool with new tokens."""

    def test_fills_partial_block(self) -> None:
        keys, vals = _kv(4, 2, 200, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert not pool.blocks[-1].is_full

        new_k, new_v = _kv(4, 2, 56, 64)
        pool.extend(new_k, new_v)
        assert pool.blocks[-1].is_full
        assert pool.total_tokens == 256

    def test_extend_creates_new_blocks(self) -> None:
        keys, vals = _kv(4, 2, 256, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(keys, vals)
        assert pool.n_blocks == 1

        new_k, new_v = _kv(4, 2, 300, 64)
        pool.extend(new_k, new_v)
        assert pool.n_blocks == 3
        assert pool.total_tokens == 556