"""
ENGRAM Protocol — ISWA Blob Parser Tests
Tests for multi-section KV cache parsing (Gemma 4 ISWA format).
Uses synthetic ISWA blobs from conftest.make_synthetic_iswa_blob().
"""
from __future__ import annotations
import pytest
import torch
from kvcos.core.blob_parser import (
BlobParseError,
ParsedKVCache,
ParsedMultiSectionCache,
parse_multi_section_blob,
parse_state_blob,
)
from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
from tests.conftest import (
GEMMA4_GLOBAL_SECTION,
GEMMA4_SECTIONS,
GEMMA4_SWA_SECTION,
make_synthetic_iswa_blob,
)
class TestParseMultiSectionBlob:
    """Successful parsing of ISWA blobs that hold multiple KV cache sections."""

    def test_parse_gemma4_shape(self) -> None:
        """Keys and values of both Gemma 4 sections have the expected shapes."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        parsed = parse_multi_section_blob(raw, GEMMA4_SECTIONS)
        assert len(parsed.sections) == 2

        # Expected tensor shape per section, in blob order.
        expected_shapes = (
            (5, 2, 4, 512),   # global section
            (25, 8, 4, 256),  # SWA section
        )
        for index, shape in enumerate(expected_shapes):
            section = parsed.sections[index]
            assert section.keys.shape == shape
            assert section.values.shape == shape

    def test_parse_metadata(self) -> None:
        """Architecture name and layer counts are reported on the result."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        parsed = parse_multi_section_blob(raw, GEMMA4_SECTIONS)

        # Aggregate metadata: 5 + 25 layers across the two sections.
        assert parsed.arch == "gemma4"
        assert parsed.n_sections == 2
        assert parsed.total_layers == 30

        # Per-section metadata.
        global_section = parsed.sections[0]
        assert global_section.n_layers == 5
        assert global_section.arch == "gemma4"
        assert parsed.sections[1].n_layers == 25

    def test_parse_cells(self) -> None:
        """Every section exposes four cells with positions running 0..3."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        parsed = parse_multi_section_blob(raw, GEMMA4_SECTIONS)

        for section in parsed.sections:
            assert section.n_cells == 4
            assert len(section.cells) == 4
            first_cell = section.cells[0]
            assert first_cell.pos == 0
            last_cell = section.cells[3]
            assert last_cell.pos == 3

    def test_dtype_float16(self) -> None:
        """Parsed key and value tensors are float16 in every section."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=2)
        parsed = parse_multi_section_blob(raw, GEMMA4_SECTIONS)

        for section in parsed.sections:
            for tensor in (section.keys, section.values):
                assert tensor.dtype == torch.float16

    def test_different_cell_counts(self) -> None:
        """A blob built with eight cells reports eight cells in both sections."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=8)
        parsed = parse_multi_section_blob(raw, GEMMA4_SECTIONS)

        for index in (0, 1):
            assert parsed.sections[index].n_cells == 8

    def test_non_transposed_v(self) -> None:
        """Blobs written with v_trans=False keep that flag after parsing."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=2, v_trans=False)
        parsed = parse_multi_section_blob(raw, GEMMA4_SECTIONS)

        for section in parsed.sections:
            assert section.v_trans is False

    def test_single_section_works(self) -> None:
        """Single-section ISWA parse should work identically to standard."""
        only_global = (GEMMA4_GLOBAL_SECTION,)
        raw = make_synthetic_iswa_blob(only_global, n_cells=4)
        parsed = parse_multi_section_blob(raw, only_global)

        assert len(parsed.sections) == 1
        sole_section = parsed.sections[0]
        assert sole_section.keys.shape == (5, 2, 4, 512)
class TestParseMultiSectionErrors:
    """Failure modes of ISWA blob parsing: each case must raise BlobParseError."""

    def test_section_mismatch_raises(self) -> None:
        """Blob has 2 sections but we pass specs for 3."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        # Append one extra spec so the caller expects more sections than exist.
        too_many_specs = GEMMA4_SECTIONS + (GEMMA4_GLOBAL_SECTION,)

        with pytest.raises(BlobParseError, match="Expected 3.*got 2"):
            parse_multi_section_blob(raw, too_many_specs)

    def test_truncated_blob_raises(self) -> None:
        """A blob cut down to its first 100 bytes is rejected."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        truncated = raw[:100]

        with pytest.raises(BlobParseError):
            parse_multi_section_blob(truncated, GEMMA4_SECTIONS)

    def test_wrong_dimensions_raises(self) -> None:
        """Pass wrong KV head count for a section."""
        raw = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        # First spec claims 4 KV heads; the blob's first section was built with 2.
        mismatched_global = CacheSection(AttentionType.FULL, 5, 4, 512)
        mismatched_specs = (mismatched_global, GEMMA4_SWA_SECTION)

        with pytest.raises(BlobParseError):
            parse_multi_section_blob(raw, mismatched_specs)
class TestStandardBlobBackwardCompat:
    """Guard that parse_state_blob keeps handling single-stream blobs."""

    def test_single_stream_still_works(self) -> None:
        """A blob from the standard test helper parses to the expected key shape."""
        # Imported here, as in the original, so the helper is only loaded
        # when this test actually runs.
        from tests.test_blob_parser import _make_blob

        # NOTE(review): the positional arguments look like
        # (n_cells, n_layers, n_kv_heads, head_dim) judging by the asserted
        # shape below -- confirm against _make_blob's signature.
        single_stream = _make_blob(16, 32, 8, 128)
        parsed = parse_state_blob(single_stream, n_kv_heads=8, head_dim=128)

        expected_shape = (32, 8, 16, 128)
        assert parsed.keys.shape == expected_shape
|