# Source: Hanrui/sglang — test/registered/disaggregation/test_specv2_kvcache_offloading.py
# Uploaded by Lekr0 ("Add files using upload-large-folder tool"), commit 61ba51e (verified).
"""
Unit tests for _release_finished_req in DecodeKVCacheOffloadManager.
Verifies that over-allocated KV cache slots (from speculative decoding v2)
are correctly freed when a request finishes, preventing GPU memory leaks.
Requires: torch, sglang (run in an environment with sglang installed)
"""
import unittest
from unittest.mock import MagicMock
import torch
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
DecodeKVCacheOffloadManager,
)
from sglang.test.ci.ci_register import register_cuda_ci
# Register this module with the CUDA CI suite "stage-b-test-small-1-gpu".
# est_time=10 is the runtime estimate handed to the CI scheduler (units
# presumably seconds — confirm against ci_register).
register_cuda_ci(suite="stage-b-test-small-1-gpu", est_time=10)
def _make_mock_req(
req_pool_idx: int,
kv_committed_len: int,
kv_allocated_len: int,
prefix_indices_len: int = 0,
):
"""Create a mock Req with the KV cache state needed for testing."""
req = MagicMock()
req.req_pool_idx = req_pool_idx
req.kv_committed_len = kv_committed_len
req.kv_allocated_len = kv_allocated_len
req.kv_committed_freed = False
req.kv_overallocated_freed = False
req.prefix_indices = list(range(prefix_indices_len))
def pop_committed():
assert not req.kv_committed_freed
req.kv_committed_freed = True
return req.kv_committed_len
def pop_overallocated():
assert not req.kv_overallocated_freed
req.kv_overallocated_freed = True
return req.kv_committed_len, req.kv_allocated_len
req.pop_committed_kv_cache = pop_committed
req.pop_overallocated_kv_cache = pop_overallocated
return req
def _make_manager(pool_size: int, page_size: int = 1, *, num_reqs: int = 1):
    """Create a DecodeKVCacheOffloadManager wired to mock pools for testing.

    Args:
        pool_size: Number of token slots per request row of ``req_to_token``.
        page_size: Value assigned to ``manager.page_size``.
        num_reqs: Number of request rows in ``req_to_token``. Row ``r`` maps
            position ``t`` to KV index ``r * pool_size + t``, so indices are
            unique across rows and a mock req may use any ``req_pool_idx`` in
            ``[0, num_reqs)``. The default of 1 yields a single identity row
            (row 0 maps ``t -> t``), matching the previous behavior.

    Returns:
        ``(manager, freed_indices)`` where ``freed_indices`` is a list that
        receives a clone of every tensor passed to
        ``manager.token_to_kv_pool_allocator.free``, in call order.
    """
    # A real tensor (not a mock) so req_to_token[idx, start:end] slicing
    # yields concrete KV indices the tests can compare against.
    req_to_token = torch.arange(num_reqs * pool_size, dtype=torch.int64).view(
        num_reqs, pool_size
    )
    req_to_token_pool = MagicMock()
    req_to_token_pool.req_to_token = req_to_token

    freed_indices = []
    allocator = MagicMock()
    # Record a clone: the freed tensor may be a view into req_to_token, and a
    # copy keeps the recorded values independent of later writes to it.
    allocator.free = MagicMock(
        side_effect=lambda idx: freed_indices.append(idx.clone())
    )

    tree_cache = MagicMock()
    tree_cache.protected_size_ = 0

    # Bypass __init__ entirely and set only the attributes these tests rely on.
    manager = object.__new__(DecodeKVCacheOffloadManager)
    manager.req_to_token_pool = req_to_token_pool
    manager.token_to_kv_pool_allocator = allocator
    manager.page_size = page_size
    manager.tree_cache = tree_cache
    return manager, freed_indices
class TestReleaseFinishedReq(unittest.TestCase):
    """Tests for _release_finished_req overallocation cleanup."""

    def _assert_freed_ranges(self, freed, ranges):
        """Assert ``freed`` matches ``ranges`` call-for-call.

        Args:
            freed: Tensors recorded from ``token_to_kv_pool_allocator.free``,
                in call order.
            ranges: Expected ``(start, end)`` pairs. Call ``i`` must equal
                ``torch.arange(start, end, dtype=torch.int64)``, which is what
                the identity row 0 of ``req_to_token`` built by
                ``_make_manager`` yields for that slice.
        """
        self.assertEqual(
            len(freed),
            len(ranges),
            f"expected {len(ranges)} allocator.free call(s), got {len(freed)}",
        )
        for call_idx, (got, (start, end)) in enumerate(zip(freed, ranges)):
            with self.subTest(free_call=call_idx, expected_range=(start, end)):
                # assert_close checks shape and dtype and reports mismatching
                # elements, unlike assertTrue(torch.equal(...)).
                torch.testing.assert_close(
                    got, torch.arange(start, end, dtype=torch.int64)
                )

    def test_no_overallocation(self):
        """Without spec v2, kv_committed == kv_allocated; no extra free."""
        manager, freed = _make_manager(pool_size=32)
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=20,
            kv_allocated_len=20,  # no overallocation
        )
        prefill_offloaded_len = 8
        manager._release_finished_req(req, prefill_offloaded_len)
        # Only one free call: the committed range [8:20].
        self._assert_freed_ranges(freed, [(8, 20)])
        manager.req_to_token_pool.free.assert_called_once_with(req)

    def test_with_overallocation(self):
        """With spec v2, overallocated slots [committed:allocated] must be freed."""
        manager, freed = _make_manager(pool_size=32)
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=20,
            kv_allocated_len=28,  # 8 over-allocated slots
        )
        prefill_offloaded_len = 8
        manager._release_finished_req(req, prefill_offloaded_len)
        # Two free calls: committed [8:20], then overallocated [20:28].
        self._assert_freed_ranges(freed, [(8, 20), (20, 28)])
        manager.req_to_token_pool.free.assert_called_once_with(req)

    def test_overallocation_with_page_alignment(self):
        """With page_size > 1, start of overallocated range is ceil-aligned."""
        page_size = 4
        manager, freed = _make_manager(pool_size=32, page_size=page_size)
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=10,  # not page-aligned
            kv_allocated_len=28,
        )
        prefill_offloaded_len = 4
        manager._release_finished_req(req, prefill_offloaded_len)
        # Committed range [4:10].
        # Overallocated: start_p = ceil_align(10, 4) = 12, end_p = 28 => [12:28].
        self._assert_freed_ranges(freed, [(4, 10), (12, 28)])
        manager.req_to_token_pool.free.assert_called_once_with(req)

    def test_overallocation_page_aligned_noop(self):
        """When ceil_align(committed, page_size) >= allocated, no overalloc free."""
        page_size = 4
        manager, freed = _make_manager(pool_size=32, page_size=page_size)
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=10,  # ceil_align(10, 4) = 12
            kv_allocated_len=12,  # same as aligned start
        )
        prefill_offloaded_len = 4
        manager._release_finished_req(req, prefill_offloaded_len)
        # Only committed [4:10]; no overalloc free because start_p == end_p.
        self._assert_freed_ranges(freed, [(4, 10)])
        manager.req_to_token_pool.free.assert_called_once_with(req)

    def test_prefix_indices_decremented(self):
        """protected_size_ is decremented by len(req.prefix_indices)."""
        manager, _ = _make_manager(pool_size=32)
        manager.tree_cache.protected_size_ = 10
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=20,
            kv_allocated_len=20,
            prefix_indices_len=5,
        )
        manager._release_finished_req(req, prefill_offloaded_len=0)
        self.assertEqual(manager.tree_cache.protected_size_, 5)
# Allow running this test module directly (outside pytest) via unittest.
if __name__ == "__main__":
    unittest.main()