"""Tests for vLLMAtomPlugin — TASK-008."""
from apohara_context_forge.serving.atom_plugin import (
    ATOMConfig,
    PostAttentionHook,
    PreAttentionHook,
    vLLMAtomPlugin,
)


class TestATOMConfig:
    """Tests for ATOMConfig."""

    def test_atom_config_defaults(self):
        """ATOMConfig has sensible defaults."""
        config = ATOMConfig()
        assert config.enable_quantization
        assert config.enable_anchor_routing
        assert config.enable_cla_injection
        assert config.quantization_mode == "rotate_kv"
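
    # Hedged sketch, not part of the original suite: assumes ATOMConfig
    # stores its constructor kwargs verbatim, as the kwargs passed elsewhere
    # in this file suggest; drop this test if the config normalizes values.
    def test_atom_config_overrides(self):
        """ATOMConfig keyword arguments override the defaults."""
        config = ATOMConfig(enable_quantization=False)
        assert not config.enable_quantization
        # Unrelated defaults should be untouched.
        assert config.quantization_mode == "rotate_kv"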


class TestvLLMAtomPlugin:
    """Tests for vLLMAtomPlugin."""

    def test_plugin_initialization(self):
        """Plugin initializes with ATOMConfig."""
        config = ATOMConfig()
        plugin = vLLMAtomPlugin(config)
        assert plugin._config is config
        assert not plugin.is_initialized()

    def test_initialize_sets_worker_id(self):
        """initialize() sets worker_id and marks initialized."""
        config = ATOMConfig()
        plugin = vLLMAtomPlugin(config)
        plugin.initialize("worker_0", {})
        assert plugin.is_initialized()
        stats = plugin.get_stats()
        assert stats["worker_id"] == "worker_0"
        assert stats["initialized"]

    def test_pre_attention_hook_returns_dict(self):
        """pre_attention_hook returns metadata dict."""
        config = ATOMConfig(enable_quantization=True)
        hook = PreAttentionHook(config)
        result = hook(["b0", "b1"], [101, 2003], layer_idx=0)
        assert isinstance(result, dict)
        assert result["quantized"] == True
        assert result["pre_rope"] == True  # INVARIANT 10
        assert result["layer_idx"] == 0

    def test_post_attention_hook_returns_dict(self):
        """post_attention_hook returns stats dict."""
        config = ATOMConfig()
        hook = PostAttentionHook(config)
        result = hook(["b0", "b1"], [], layer_idx=0)
        assert isinstance(result, dict)
        assert result["processed_blocks"] == 2
        assert result["layer_idx"] == 0

    def test_plugin_pre_attention_hook_property(self):
        """Plugin exposes pre_attention_hook as property."""
        config = ATOMConfig()
        plugin = vLLMAtomPlugin(config)
        assert hasattr(plugin, "pre_attention_hook")
        assert callable(plugin.pre_attention_hook)

    def test_plugin_post_attention_hook_property(self):
        """Plugin exposes post_attention_hook as property."""
        config = ATOMConfig()
        plugin = vLLMAtomPlugin(config)
        assert hasattr(plugin, "post_attention_hook")
        assert callable(plugin.post_attention_hook)

    def test_get_stats_returns_config_and_state(self):
        """get_stats returns configuration and state."""
        config = ATOMConfig(
            enable_quantization=True,
            enable_anchor_routing=False,
            enable_cla_injection=True,
            quantization_mode="rotate_kv",
        )
        plugin = vLLMAtomPlugin(config)
        plugin.initialize("worker_test", {})

        stats = plugin.get_stats()
        assert stats["initialized"] == True
        assert stats["worker_id"] == "worker_test"
        assert stats["config"]["enable_quantization"] == True
        assert stats["config"]["quantization_mode"] == "rotate_kv"