File size: 4,244 Bytes
2ece486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
ENGRAM Protocol — Serializer Tests
Tests for .eng safetensors serialize/deserialize round-trip (D7).
"""

from __future__ import annotations

from pathlib import Path

import pytest
import torch
from safetensors.torch import load_file

from kvcos.core.serializer import EngramSerializer, SerializationError
from kvcos.core.types import CompressionMethod
from tests.conftest import make_synthetic_kv
from kvcos.core.cache_spec import LLAMA_3_1_8B


class TestSerializeRoundTrip:
    """Serialize → deserialize preserves shape, dtype, metadata."""

    def test_round_trip_shape(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256)
        s = EngramSerializer()
        eng = tmp_data_dir / "test.eng"

        s.serialize(
            keys=keys, values=values,
            agent_id="test-agent", task_description="unit test",
            model_id=LLAMA_3_1_8B["model_id"], output_path=eng,
            compression=CompressionMethod.FP16,
        )
        k_out, v_out, meta = s.deserialize(eng)

        assert k_out.shape == keys.shape
        assert v_out.shape == values.shape

    def test_metadata_fields(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        s = EngramSerializer()
        eng = tmp_data_dir / "meta.eng"

        s.serialize(
            keys=keys, values=values,
            agent_id="agent-42", task_description="metadata check",
            model_id=LLAMA_3_1_8B["model_id"], output_path=eng,
            compression=CompressionMethod.Q8_0,
        )
        _, _, meta = s.deserialize(eng)

        assert meta["agent_id"] == "agent-42"
        assert meta["task_description"] == "metadata check"
        assert meta["compression"] == "q8_0"
        assert meta["n_layers"] == "32"
        assert meta["model_family"] == "llama"

    def test_safetensors_loadable(self, tmp_data_dir: Path) -> None:
        """D7: File must be valid safetensors."""
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        s = EngramSerializer()
        eng = tmp_data_dir / "valid.eng"

        s.serialize(
            keys=keys, values=values,
            agent_id="test", task_description="safetensors check",
            model_id=LLAMA_3_1_8B["model_id"], output_path=eng,
            compression=CompressionMethod.FP16,
        )
        tensors = load_file(str(eng))
        assert "layer_0_keys" in tensors
        assert "layer_0_values" in tensors

    def test_result_dict(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        s = EngramSerializer()
        eng = tmp_data_dir / "result.eng"

        result = s.serialize(
            keys=keys, values=values,
            agent_id="test", task_description="result check",
            model_id=LLAMA_3_1_8B["model_id"], output_path=eng,
        )
        assert "cache_id" in result
        assert result["size_bytes"] > 0
        assert result["n_layers"] == 32


class TestSerializerErrors:
    """Edge cases and error handling."""

    def test_shape_mismatch_raises(self, tmp_data_dir: Path) -> None:
        keys = torch.randn(32, 8, 64, 128, dtype=torch.float16)
        values = torch.randn(32, 8, 32, 128, dtype=torch.float16)
        s = EngramSerializer()

        with pytest.raises(SerializationError, match="mismatch"):
            s.serialize(
                keys=keys, values=values,
                agent_id="t", task_description="t",
                model_id="test", output_path=tmp_data_dir / "bad.eng",
            )

    def test_3d_tensor_raises(self, tmp_data_dir: Path) -> None:
        keys = torch.randn(8, 64, 128, dtype=torch.float16)
        s = EngramSerializer()

        with pytest.raises(SerializationError, match="4D"):
            s.serialize(
                keys=keys, values=keys,
                agent_id="t", task_description="t",
                model_id="test", output_path=tmp_data_dir / "bad.eng",
            )

    def test_missing_file_raises(self, tmp_data_dir: Path) -> None:
        s = EngramSerializer()
        with pytest.raises(SerializationError, match="not found"):
            s.deserialize(tmp_data_dir / "nonexistent.eng")