| """ |
| ENGRAM Protocol — ISWA Blob Parser Tests |
| Tests for multi-section KV cache parsing (Gemma 4 ISWA format). |
| |
| Uses synthetic ISWA blobs from conftest.make_synthetic_iswa_blob(). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import pytest |
| import torch |
|
|
| from kvcos.core.blob_parser import ( |
| BlobParseError, |
| ParsedKVCache, |
| ParsedMultiSectionCache, |
| parse_multi_section_blob, |
| parse_state_blob, |
| ) |
| from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec |
| from tests.conftest import ( |
| GEMMA4_GLOBAL_SECTION, |
| GEMMA4_SECTIONS, |
| GEMMA4_SWA_SECTION, |
| make_synthetic_iswa_blob, |
| ) |
|
|
|
|
| class TestParseMultiSectionBlob: |
| """Parse ISWA blobs with multiple KV cache sections.""" |
|
|
| def test_parse_gemma4_shape(self) -> None: |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4) |
| result = parse_multi_section_blob(blob, GEMMA4_SECTIONS) |
|
|
| assert len(result.sections) == 2 |
|
|
| |
| s0 = result.sections[0] |
| assert s0.keys.shape == (5, 2, 4, 512) |
| assert s0.values.shape == (5, 2, 4, 512) |
|
|
| |
| s1 = result.sections[1] |
| assert s1.keys.shape == (25, 8, 4, 256) |
| assert s1.values.shape == (25, 8, 4, 256) |
|
|
| def test_parse_metadata(self) -> None: |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4) |
| result = parse_multi_section_blob(blob, GEMMA4_SECTIONS) |
|
|
| assert result.arch == "gemma4" |
| assert result.n_sections == 2 |
| assert result.total_layers == 30 |
|
|
| assert result.sections[0].n_layers == 5 |
| assert result.sections[0].arch == "gemma4" |
| assert result.sections[1].n_layers == 25 |
|
|
| def test_parse_cells(self) -> None: |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4) |
| result = parse_multi_section_blob(blob, GEMMA4_SECTIONS) |
|
|
| for sec in result.sections: |
| assert sec.n_cells == 4 |
| assert len(sec.cells) == 4 |
| assert sec.cells[0].pos == 0 |
| assert sec.cells[3].pos == 3 |
|
|
| def test_dtype_float16(self) -> None: |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=2) |
| result = parse_multi_section_blob(blob, GEMMA4_SECTIONS) |
|
|
| for sec in result.sections: |
| assert sec.keys.dtype == torch.float16 |
| assert sec.values.dtype == torch.float16 |
|
|
| def test_different_cell_counts(self) -> None: |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=8) |
| result = parse_multi_section_blob(blob, GEMMA4_SECTIONS) |
|
|
| assert result.sections[0].n_cells == 8 |
| assert result.sections[1].n_cells == 8 |
|
|
| def test_non_transposed_v(self) -> None: |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=2, v_trans=False) |
| result = parse_multi_section_blob(blob, GEMMA4_SECTIONS) |
|
|
| for sec in result.sections: |
| assert sec.v_trans is False |
|
|
| def test_single_section_works(self) -> None: |
| """Single-section ISWA parse should work identically to standard.""" |
| single = (GEMMA4_GLOBAL_SECTION,) |
| blob = make_synthetic_iswa_blob(single, n_cells=4) |
| result = parse_multi_section_blob(blob, single) |
|
|
| assert len(result.sections) == 1 |
| assert result.sections[0].keys.shape == (5, 2, 4, 512) |
|
|
|
|
| class TestParseMultiSectionErrors: |
| """Error handling for ISWA blob parsing.""" |
|
|
| def test_section_mismatch_raises(self) -> None: |
| """Blob has 2 sections but we pass specs for 3.""" |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4) |
| three_sections = GEMMA4_SECTIONS + (GEMMA4_GLOBAL_SECTION,) |
| with pytest.raises(BlobParseError, match="Expected 3.*got 2"): |
| parse_multi_section_blob(blob, three_sections) |
|
|
| def test_truncated_blob_raises(self) -> None: |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4) |
| with pytest.raises(BlobParseError): |
| parse_multi_section_blob(blob[:100], GEMMA4_SECTIONS) |
|
|
| def test_wrong_dimensions_raises(self) -> None: |
| """Pass wrong KV head count for a section.""" |
| wrong_sections = ( |
| CacheSection(AttentionType.FULL, 5, 4, 512), |
| GEMMA4_SWA_SECTION, |
| ) |
| blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4) |
| with pytest.raises(BlobParseError): |
| parse_multi_section_blob(blob, wrong_sections) |
|
|
|
|
| class TestStandardBlobBackwardCompat: |
| """Ensure parse_state_blob still works for single-stream blobs.""" |
|
|
| def test_single_stream_still_works(self) -> None: |
| from tests.test_blob_parser import _make_blob |
|
|
| blob = _make_blob(16, 32, 8, 128) |
| result = parse_state_blob(blob, n_kv_heads=8, head_dim=128) |
| assert result.keys.shape == (32, 8, 16, 128) |
|
|