|
|
""" |
|
|
Integration tests for JaggedKVCache with inference pipeline. |
|
|
|
|
|
Run with: pytest tests/test_cache_integration.py -v |
|
|
""" |
|
|
|
|
|
import pytest |
|
|
import torch |
|
|
from typing import List, Optional |
|
|
|
|
|
|
|
|
import sys
from pathlib import Path

# Make the project root importable so that `src.jagged_cache` resolves.
# This file is run as `tests/test_cache_integration.py`, so the root is the
# parent of this file's directory; deriving it from __file__ avoids tying the
# test suite to one machine's absolute checkout path.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
|
|
|
|
from src.jagged_cache import JaggedKVCache |
|
|
|
|
|
|
|
|
class TestJaggedKVCacheProduction:
    """Test the production JaggedKVCache implementation.

    Each test writes single-position key/value tensors into a small CPU cache
    and then checks what the cache reports about which slots are filled.
    """

    @pytest.fixture
    def cache(self):
        """Create a fresh 8-layer, batch-size-1, float32 cache on the CPU."""
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        """Return a factory that builds random ``(k, v)`` tensor pairs.

        The factory produces tensors of shape
        ``(batch_size, num_heads, seq_len, head_dim)``; its defaults match the
        head count and head dimension used by the ``cache`` fixture.
        """

        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            k = torch.randn(batch_size, num_heads, seq_len, head_dim)
            v = torch.randn(batch_size, num_heads, seq_len, head_dim)
            return k, v

        return _make_kv

    def test_filled_positions_tracking(self, cache, sample_kv):
        """Test that filled_positions correctly tracks which positions are filled."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        # Only the slot that was written (first argument 0, position 0) may
        # be reported as present.
        assert cache.has_position(0, 0)
        assert not cache.has_position(0, 1)
        assert not cache.has_position(1, 0)

    def test_needs_fill(self, cache, sample_kv):
        """Test needs_fill correctly identifies missing positions."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        # The slot that was just written does not need filling.
        assert not cache.needs_fill(0, [0])

        # A position that was never written does.
        assert cache.needs_fill(0, [1])

        # So does the same position queried with a different first argument.
        assert cache.needs_fill(1, [0])

    def test_get_unfilled_positions(self, cache, sample_kv):
        """Test getting unfilled positions."""
        # Write positions 0 and 2, deliberately leaving gaps at 1 and 3.
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([2]))

        unfilled = cache.get_unfilled_positions(0, 4)
        assert unfilled == [1, 3]

    def test_truncate_clears_filled_positions(self, cache, sample_kv):
        """Test that truncation also clears filled_positions."""
        for pos in range(5):
            k, v = sample_kv()
            cache.update(0, k, v, torch.tensor([pos]))

        assert cache.has_position(0, 4)

        cache.truncate_from(3)

        # Position 2 must survive; positions 3 and 4 must be gone.
        assert cache.has_position(0, 2)
        assert not cache.has_position(0, 3)
        assert not cache.has_position(0, 4)

    def test_clone_copies_filled_positions(self, cache, sample_kv):
        """Test that clone also copies filled_positions."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cloned = cache.clone()

        assert cloned.has_position(0, 0)

        # Writing to the original after cloning must not leak into the clone.
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([1]))

        assert cache.has_position(0, 1)
        assert not cloned.has_position(0, 1)

    def test_reset(self, cache, sample_kv):
        """Test that reset clears everything."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cache.reset()

        assert cache.get_kv(0) is None
        assert cache.get_seq_length(0) == 0
        assert not cache.has_position(0, 0)
|
|
|
|
|
|
|
|
class TestLazyFillScenario:
    """Scenario-level checks of JaggedKVCache's lazy-fill bookkeeping."""

    @pytest.fixture
    def cache(self):
        """Build the cache under test: 8 layers, batch 1, 4 KV heads of width 64."""
        settings = dict(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )
        return JaggedKVCache(**settings)

    @pytest.fixture
    def sample_kv(self):
        """Provide a callable that draws a random key tensor, then a value tensor."""

        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            shape = (batch_size, num_heads, seq_len, head_dim)
            keys = torch.randn(shape)
            values = torch.randn(shape)
            return keys, values

        return _make_kv

    def test_lazy_fill_scenario(self, cache, sample_kv):
        """
        Walk through a draft-decoding sequence:

        - the prompt (positions 0-4) is prefilled through all layers
        - draft token 5 exits at layer 2
        - draft token 6 exits at layer 6, which requires a lazy fill
        """
        every_layer = range(8)
        prompt_positions = range(5)

        def write(layer, position):
            # Draw fresh KV tensors and store them for one (layer, position).
            key, value = sample_kv()
            cache.update(layer, key, value, torch.tensor([position]))

        # Prefill: each prompt position is written to every layer.
        for position in prompt_positions:
            for layer in every_layer:
                write(layer, position)

        # After prefill every layer reports the full prompt.
        for layer in every_layer:
            assert cache.get_seq_length(layer) == 5
            for position in prompt_positions:
                assert cache.has_position(layer, position)

        # Token 5 exits early, so only layers 0, 1 and 2 receive it.
        for layer in range(3):
            write(layer, 5)

        assert cache.has_position(0, 5)
        assert cache.has_position(2, 5)
        assert not cache.has_position(3, 5)

        # Going up to layer 6, position 5 is absent from the layers that the
        # early exit skipped, and present in the ones it reached.
        missing = cache.get_missing_layers(5, 6)

        assert 3 in missing
        assert 6 in missing
        assert 0 not in missing

        # Layer 6 never saw position 5, so it is reported as unfilled there.
        unfilled = cache.get_unfilled_positions(6, 6)
        assert 5 in unfilled
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Forward pytest's exit code to the shell so that running this file
    # directly reports failure (non-zero status) when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|