""" ENGRAM Protocol — Serializer Tests Tests for .eng safetensors serialize/deserialize round-trip (D7). """ from __future__ import annotations from pathlib import Path import pytest import torch from safetensors.torch import load_file from kvcos.core.serializer import EngramSerializer, SerializationError from kvcos.core.types import CompressionMethod from tests.conftest import make_synthetic_kv from kvcos.core.cache_spec import LLAMA_3_1_8B class TestSerializeRoundTrip: """Serialize → deserialize preserves shape, dtype, metadata.""" def test_round_trip_shape(self, tmp_data_dir: Path) -> None: keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256) s = EngramSerializer() eng = tmp_data_dir / "test.eng" s.serialize( keys=keys, values=values, agent_id="test-agent", task_description="unit test", model_id=LLAMA_3_1_8B["model_id"], output_path=eng, compression=CompressionMethod.FP16, ) k_out, v_out, meta = s.deserialize(eng) assert k_out.shape == keys.shape assert v_out.shape == values.shape def test_metadata_fields(self, tmp_data_dir: Path) -> None: keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64) s = EngramSerializer() eng = tmp_data_dir / "meta.eng" s.serialize( keys=keys, values=values, agent_id="agent-42", task_description="metadata check", model_id=LLAMA_3_1_8B["model_id"], output_path=eng, compression=CompressionMethod.Q8_0, ) _, _, meta = s.deserialize(eng) assert meta["agent_id"] == "agent-42" assert meta["task_description"] == "metadata check" assert meta["compression"] == "q8_0" assert meta["n_layers"] == "32" assert meta["model_family"] == "llama" def test_safetensors_loadable(self, tmp_data_dir: Path) -> None: """D7: File must be valid safetensors.""" keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64) s = EngramSerializer() eng = tmp_data_dir / "valid.eng" s.serialize( keys=keys, values=values, agent_id="test", task_description="safetensors check", model_id=LLAMA_3_1_8B["model_id"], output_path=eng, compression=CompressionMethod.FP16, ) tensors = load_file(str(eng)) assert "layer_0_keys" in tensors assert "layer_0_values" in tensors def test_result_dict(self, tmp_data_dir: Path) -> None: keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64) s = EngramSerializer() eng = tmp_data_dir / "result.eng" result = s.serialize( keys=keys, values=values, agent_id="test", task_description="result check", model_id=LLAMA_3_1_8B["model_id"], output_path=eng, ) assert "cache_id" in result assert result["size_bytes"] > 0 assert result["n_layers"] == 32 class TestSerializerErrors: """Edge cases and error handling.""" def test_shape_mismatch_raises(self, tmp_data_dir: Path) -> None: keys = torch.randn(32, 8, 64, 128, dtype=torch.float16) values = torch.randn(32, 8, 32, 128, dtype=torch.float16) s = EngramSerializer() with pytest.raises(SerializationError, match="mismatch"): s.serialize( keys=keys, values=values, agent_id="t", task_description="t", model_id="test", output_path=tmp_data_dir / "bad.eng", ) def test_3d_tensor_raises(self, tmp_data_dir: Path) -> None: keys = torch.randn(8, 64, 128, dtype=torch.float16) s = EngramSerializer() with pytest.raises(SerializationError, match="4D"): s.serialize( keys=keys, values=keys, agent_id="t", task_description="t", model_id="test", output_path=tmp_data_dir / "bad.eng", ) def test_missing_file_raises(self, tmp_data_dir: Path) -> None: s = EngramSerializer() with pytest.raises(SerializationError, match="not found"): s.deserialize(tmp_data_dir / "nonexistent.eng")