Dssd_Demo / tests /test_cache_integration.py
Florian valade
Track metrics during streaming, remove redundant generation re-runs
33efa44
"""
Integration tests for JaggedKVCache with inference pipeline.
Run with: pytest tests/test_cache_integration.py -v
"""
import pytest
import torch
from typing import List, Optional
# Import from production module.
# This file lives in ``<repo>/tests/``, so the repository root is the parent
# of this file's directory. Deriving it from ``__file__`` (instead of a
# hard-coded, machine-specific absolute path) lets the test run from any
# checkout location, including CI.
import sys
from pathlib import Path

_REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(_REPO_ROOT))
from src.jagged_cache import JaggedKVCache
class TestJaggedKVCacheProduction:
    """Test the production JaggedKVCache implementation."""

    @pytest.fixture
    def cache(self):
        """Create an empty 8-layer, batch-size-1 float32 cache on CPU."""
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        """Return a factory that builds random ``(k, v)`` tensor pairs.

        Each tensor has shape ``(batch_size, num_heads, seq_len, head_dim)``;
        the defaults (1, 4, 1, 64) match the ``cache`` fixture and describe a
        single token.
        """

        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            k = torch.randn(batch_size, num_heads, seq_len, head_dim)
            v = torch.randn(batch_size, num_heads, seq_len, head_dim)
            return k, v

        return _make_kv

    def test_filled_positions_tracking(self, cache, sample_kv):
        """Test that filled_positions correctly tracks which positions are filled."""
        # Update layer 0 with position 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        assert cache.has_position(0, 0)
        assert not cache.has_position(0, 1)
        assert not cache.has_position(1, 0)  # Layer 1 not touched

    def test_needs_fill(self, cache, sample_kv):
        """Test needs_fill correctly identifies missing positions."""
        # Fill layer 0 with position 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        # Layer 0 has position 0, doesn't need fill
        assert not cache.needs_fill(0, [0])
        # Layer 0 doesn't have position 1
        assert cache.needs_fill(0, [1])
        # Layer 1 has nothing
        assert cache.needs_fill(1, [0])

    def test_get_unfilled_positions(self, cache, sample_kv):
        """Test getting unfilled positions below an exclusive upper bound."""
        # Fill positions 0 and 2 for layer 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([2]))

        # Unfilled positions below 4 should be [1, 3]
        unfilled = cache.get_unfilled_positions(0, 4)
        assert unfilled == [1, 3]

    def test_truncate_clears_filled_positions(self, cache, sample_kv):
        """Test that truncation also clears filled_positions."""
        # Fill positions 0-4
        for pos in range(5):
            k, v = sample_kv()
            cache.update(0, k, v, torch.tensor([pos]))

        assert cache.has_position(0, 4)

        # Truncate at position 3
        cache.truncate_from(3)

        # Position 2 survives; positions 3 and 4 should be gone
        assert cache.has_position(0, 2)
        assert not cache.has_position(0, 3)
        assert not cache.has_position(0, 4)

    def test_clone_copies_filled_positions(self, cache, sample_kv):
        """Test that clone also copies filled_positions."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cloned = cache.clone()
        assert cloned.has_position(0, 0)

        # Modify original
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([1]))

        # Clone should be unaffected
        assert cache.has_position(0, 1)
        assert not cloned.has_position(0, 1)

    def test_reset(self, cache, sample_kv):
        """Test that reset clears everything."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cache.reset()

        assert cache.get_kv(0) is None
        assert cache.get_seq_length(0) == 0
        assert not cache.has_position(0, 0)
class TestLazyFillScenario:
    """Test realistic lazy fill scenarios."""

    @pytest.fixture
    def cache(self):
        """Create an empty 8-layer, batch-size-1 float32 cache on CPU."""
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        """Return a factory that builds random ``(k, v)`` tensor pairs.

        Each tensor has shape ``(batch_size, num_heads, seq_len, head_dim)``;
        the defaults describe a single token for the ``cache`` fixture.
        """

        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            k = torch.randn(batch_size, num_heads, seq_len, head_dim)
            v = torch.randn(batch_size, num_heads, seq_len, head_dim)
            return k, v

        return _make_kv

    def test_lazy_fill_scenario(self, cache, sample_kv):
        """
        Simulate:
        - Prefill prompt (positions 0-4) through all layers
        - Draft token 5 exiting at layer 2
        - Draft token 6 exiting at layer 6 (needs lazy fill)
        """
        num_layers = 8
        prompt_len = 5
        first_exit_layer = 2
        second_exit_layer = 6

        # Prefill: positions 0-4 through all 8 layers
        for pos in range(prompt_len):
            for layer_idx in range(num_layers):
                k, v = sample_kv()
                cache.update(layer_idx, k, v, torch.tensor([pos]))

        # Verify prefill complete
        for layer_idx in range(num_layers):
            assert cache.get_seq_length(layer_idx) == prompt_len
            for pos in range(prompt_len):
                assert cache.has_position(layer_idx, pos)

        # Draft token 5, exit at layer 2: only layers 0, 1, 2 compute its KV
        for layer_idx in range(first_exit_layer + 1):
            k, v = sample_kv()
            cache.update(layer_idx, k, v, torch.tensor([5]))

        # Position 5 is filled only for layers 0-2
        for layer_idx in range(first_exit_layer + 1):
            assert cache.has_position(layer_idx, 5)
        for layer_idx in range(first_exit_layer + 1, num_layers):
            assert not cache.has_position(layer_idx, 5)

        # Draft token 6, need to exit at layer 6.
        # Check which layers up to layer 6 are missing position 5.
        missing_at_layer_6 = cache.get_missing_layers(5, second_exit_layer)

        # Every layer 3-6 is missing position 5 ...
        for layer_idx in range(first_exit_layer + 1, second_exit_layer + 1):
            assert layer_idx in missing_at_layer_6
        # ... and none of layers 0-2 (they already hold position 5)
        for layer_idx in range(first_exit_layer + 1):
            assert layer_idx not in missing_at_layer_6

        # Unfilled positions at layer 6 below position 6 (exclusive bound, as
        # in test_get_unfilled_positions): only position 5 was skipped.
        unfilled = cache.get_unfilled_positions(second_exit_layer, 6)
        assert unfilled == [5]
if __name__ == "__main__":
    # pytest.main returns an exit code rather than exiting; propagate it so
    # running this file directly reports failures to the shell/CI.
    raise SystemExit(pytest.main([__file__, "-v"]))