"""
Integration tests for JaggedKVCache with inference pipeline.

Run with: pytest tests/test_cache_integration.py -v
"""

import sys

import pytest
import torch

# Make the production module importable (path to the local checkout).
sys.path.insert(0, "/home/fvalade/workspace/DSSD_demo")

from src.jagged_cache import JaggedKVCache


class TestJaggedKVCacheProduction:
    """Test the production JaggedKVCache implementation."""

    @pytest.fixture
    def cache(self):
        """Create a test cache."""
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        """Create sample KV tensors."""

        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            k = torch.randn(batch_size, num_heads, seq_len, head_dim)
            v = torch.randn(batch_size, num_heads, seq_len, head_dim)
            return k, v

        return _make_kv

    def test_filled_positions_tracking(self, cache, sample_kv):
        """Test that filled_positions correctly tracks which positions are filled."""
        # Update layer 0 with position 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        assert cache.has_position(0, 0)
        assert not cache.has_position(0, 1)
        assert not cache.has_position(1, 0)  # Layer 1 was never touched

    def test_needs_fill(self, cache, sample_kv):
        """Test needs_fill correctly identifies missing positions."""
        # Fill layer 0 with position 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        # Layer 0 has position 0, so no fill is needed
        assert not cache.needs_fill(0, [0])

        # Layer 0 does not have position 1
        assert cache.needs_fill(0, [1])

        # Layer 1 has nothing yet
        assert cache.needs_fill(1, [0])
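
    # Illustrative addition, not part of the original suite: assumes
    # needs_fill reports True when any position in the list is missing.
    def test_needs_fill_mixed_positions(self, cache, sample_kv):
        """Sketch: a mix of filled and unfilled positions should need a fill."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        # Position 0 is filled, position 1 is not; a fill should be needed.
        assert cache.needs_fill(0, [0, 1])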

    def test_get_unfilled_positions(self, cache, sample_kv):
        """Test getting unfilled positions."""
        # Fill positions 0 and 2 for layer 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([2]))

        # Unfilled up to position 4 should be [1, 3]
        unfilled = cache.get_unfilled_positions(0, 4)
        assert unfilled == [1, 3]
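
    # Illustrative addition, not part of the original suite: assumes
    # get_unfilled_positions scans the half-open range [0, max_pos),
    # consistent with the [1, 3] result checked above.
    def test_unfilled_positions_cleared_after_fill(self, cache, sample_kv):
        """Sketch: back-filling the gaps should empty the unfilled list."""
        for pos in [0, 2]:
            k, v = sample_kv()
            cache.update(0, k, v, torch.tensor([pos]))

        # Fill the gap at position 1; positions 0-2 are now all covered.
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([1]))
        assert cache.get_unfilled_positions(0, 3) == []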

    def test_truncate_clears_filled_positions(self, cache, sample_kv):
        """Test that truncation also clears filled_positions."""
        # Fill positions 0-4
        for pos in range(5):
            k, v = sample_kv()
            cache.update(0, k, v, torch.tensor([pos]))

        assert cache.has_position(0, 4)

        # Truncate from position 3 onward
        cache.truncate_from(3)

        # Positions 3 and 4 should be gone
        assert cache.has_position(0, 2)
        assert not cache.has_position(0, 3)
        assert not cache.has_position(0, 4)

    def test_clone_copies_filled_positions(self, cache, sample_kv):
        """Test that clone also copies filled_positions."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cloned = cache.clone()

        assert cloned.has_position(0, 0)

        # Modify original
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([1]))

        # Clone should be unaffected
        assert cache.has_position(0, 1)
        assert not cloned.has_position(0, 1)

    def test_reset(self, cache, sample_kv):
        """Test that reset clears everything."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cache.reset()

        assert cache.get_kv(0) is None
        assert cache.get_seq_length(0) == 0
        assert not cache.has_position(0, 0)
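
    # Illustrative addition, not part of the original suite: assumes a
    # reset cache accepts updates again exactly like a freshly built one.
    def test_reuse_after_reset(self, cache, sample_kv):
        """Sketch: the cache should be fully usable after reset."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))
        cache.reset()

        # A second round of updates should behave like the first.
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))
        assert cache.has_position(0, 0)
        assert cache.get_seq_length(0) == 1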


class TestLazyFillScenario:
    """Test realistic lazy fill scenarios."""

    @pytest.fixture
    def cache(self):
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            k = torch.randn(batch_size, num_heads, seq_len, head_dim)
            v = torch.randn(batch_size, num_heads, seq_len, head_dim)
            return k, v

        return _make_kv

    def test_lazy_fill_scenario(self, cache, sample_kv):
        """
        Simulate:
        - Prefill prompt (positions 0-4) through all layers
        - Draft token 5 exiting at layer 2
        - Draft token 6 exiting at layer 6 (needs lazy fill)
        """
        # Prefill: positions 0-4 through all 8 layers
        for pos in range(5):
            for layer_idx in range(8):
                k, v = sample_kv()
                cache.update(layer_idx, k, v, torch.tensor([pos]))

        # Verify prefill complete
        for layer_idx in range(8):
            assert cache.get_seq_length(layer_idx) == 5
            for pos in range(5):
                assert cache.has_position(layer_idx, pos)

        # Draft token 5, exit at layer 2
        for layer_idx in range(3):  # Layers 0, 1, 2
            k, v = sample_kv()
            cache.update(layer_idx, k, v, torch.tensor([5]))

        # Position 5 is filled only for layers 0-2
        assert cache.has_position(0, 5)
        assert cache.has_position(2, 5)
        assert not cache.has_position(3, 5)

        # Draft token 6 needs to exit at layer 6, so attention at layers 3-6
        # requires position 5. Ask which layers up to layer 6 are missing it.
        missing_at_layer_6 = cache.get_missing_layers(5, 6)

        # Layers 3-6 are missing position 5
        assert 3 in missing_at_layer_6
        assert 6 in missing_at_layer_6
        assert 0 not in missing_at_layer_6  # Layer 0 has position 5

        # Check unfilled positions for layer 6 up to position 6
        unfilled = cache.get_unfilled_positions(6, 6)
        assert 5 in unfilled  # Position 5 is unfilled at layer 6
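
    # Illustrative continuation, not part of the original suite: assumes
    # back-filling position 5 into layers 3-6 clears the missing-layer set
    # reported by get_missing_layers in the scenario above.
    def test_lazy_fill_completion(self, cache, sample_kv):
        """Sketch: after the lazy fill, layer 6 no longer needs position 5."""
        # Prefill positions 0-4 through all 8 layers, as in the scenario above.
        for pos in range(5):
            for layer_idx in range(8):
                k, v = sample_kv()
                cache.update(layer_idx, k, v, torch.tensor([pos]))

        # Draft token 5 exits at layer 2, so only layers 0-2 hold position 5.
        for layer_idx in range(3):
            k, v = sample_kv()
            cache.update(layer_idx, k, v, torch.tensor([5]))

        # Lazy fill: write position 5 into layers 3-6.
        for layer_idx in range(3, 7):
            k, v = sample_kv()
            cache.update(layer_idx, k, v, torch.tensor([5]))

        # Nothing up to layer 6 is missing position 5 any more.
        assert not cache.get_missing_layers(5, 6)
        assert not cache.needs_fill(6, [5])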


if __name__ == "__main__":
    pytest.main([__file__, "-v"])