File size: 3,623 Bytes
2ece486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
ENGRAM Protocol — State Extractor Tests
Tests for all 3 EGR extraction modes (D3).
"""

from __future__ import annotations

import torch

from kvcos.core.cache_spec import LLAMA_3_1_8B, PHI_3_MINI
from kvcos.core.types import StateExtractionMode
from kvcos.core.state_extractor import MARStateExtractor
from tests.conftest import make_synthetic_kv


class TestMeanPool:
    """MEAN_POOL mode: mean over layers, heads, context → [head_dim]."""

    def test_output_dim(self) -> None:
        """The extracted state vector for LLAMA_3_1_8B keys has shape (128,)."""
        synthetic_keys, _values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        extraction = extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        vec_shape = extraction.state_vec.shape
        assert vec_shape == (128,)

    def test_output_dim_api(self) -> None:
        """output_dim() reports 128 for LLAMA_3_1_8B without running extract()."""
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        reported_dim = extractor.output_dim(LLAMA_3_1_8B)
        assert reported_dim == 128

    def test_l2_norm_positive(self) -> None:
        """The reported L2 norm of the extraction is strictly positive."""
        synthetic_keys, _values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        extraction = extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        norm = extraction.l2_norm
        assert norm > 0

    def test_deterministic(self) -> None:
        """Two extractions from the same keys yield identical state vectors."""
        synthetic_keys, _values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        # Reuse one extractor and one key tensor for both runs.
        first_run = extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        second_run = extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        vectors_match = torch.equal(first_run.state_vec, second_run.state_vec)
        assert vectors_match


class TestSVDProject:
    """SVD_PROJECT mode: truncated SVD, rank-160 → [rank]."""

    def test_output_dim(self) -> None:
        """Requesting rank=160 yields a (128,) vector for LLAMA_3_1_8B."""
        synthetic_keys, _values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        extractor = MARStateExtractor(
            mode=StateExtractionMode.SVD_PROJECT,
            rank=160,
        )
        extraction = extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        vec_shape = extraction.state_vec.shape
        # The requested rank is clamped to head_dim.
        assert vec_shape == (128,)

    def test_projection_stored(self) -> None:
        """extract() leaves a projection on the extractor with a ratio in (0, 1]."""
        synthetic_keys, _values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        extractor = MARStateExtractor(
            mode=StateExtractionMode.SVD_PROJECT,
            rank=160,
        )
        # Only the side effect on the extractor matters; the result is unused.
        extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        stored = extractor.last_projection
        assert stored is not None
        variance_ratio = stored.explained_variance_ratio
        assert 0.0 < variance_ratio <= 1.0

    def test_n_layers_used(self) -> None:
        """With the default rank, the result reports 24 layers used."""
        synthetic_keys, _values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        extractor = MARStateExtractor(mode=StateExtractionMode.SVD_PROJECT)
        extraction = extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        layers_used = extraction.n_layers_used
        # Layers 8-31.
        assert layers_used == 24


class TestXKVProject:
    """XKV_PROJECT mode: grouped cross-layer SVD."""

    def test_output_dim(self) -> None:
        """The extracted vector length agrees with what output_dim() reports."""
        synthetic_keys, _values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        extractor = MARStateExtractor(
            mode=StateExtractionMode.XKV_PROJECT,
            rank=160,
        )
        # Extract first, then query output_dim(), as the original test does.
        extraction = extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        reported_dim = extractor.output_dim(LLAMA_3_1_8B)
        assert extraction.state_vec.shape == (reported_dim,)

    def test_different_from_mean_pool(self) -> None:
        """MEAN_POOL and XKV_PROJECT produce differently shaped state vectors."""
        synthetic_keys, _values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        mean_pool_extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        xkv_extractor = MARStateExtractor(mode=StateExtractionMode.XKV_PROJECT)
        mean_pool_result = mean_pool_extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        xkv_result = xkv_extractor.extract(synthetic_keys, LLAMA_3_1_8B)
        mean_pool_shape = mean_pool_result.state_vec.shape
        xkv_shape = xkv_result.state_vec.shape
        assert mean_pool_shape != xkv_shape

    def test_phi3_works(self) -> None:
        """PHI_3_MINI keys with rank=96 give a non-empty one-dimensional vector."""
        synthetic_keys, _values = make_synthetic_kv(PHI_3_MINI, ctx_len=64)
        extractor = MARStateExtractor(
            mode=StateExtractionMode.XKV_PROJECT,
            rank=96,
        )
        extraction = extractor.extract(synthetic_keys, PHI_3_MINI)
        vec = extraction.state_vec
        assert vec.dim() == 1
        assert vec.shape[0] > 0