"""
Unit tests for _release_finished_req in DecodeKVCacheOffloadManager.

Verifies that over-allocated KV cache slots (from speculative decoding v2)
are correctly freed when a request finishes, preventing GPU memory leaks.

Requires: torch, sglang (run in an environment with sglang installed).
"""
import unittest
from unittest.mock import MagicMock
import torch
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
DecodeKVCacheOffloadManager,
)
from sglang.test.ci.ci_register import register_cuda_ci
# NOTE(review): presumably registers this test module with the CUDA CI harness
# under the named suite, with est_time as a runtime estimate (unit not shown
# in this file) — confirm against sglang.test.ci.ci_register.
register_cuda_ci(est_time=10, suite="stage-b-test-small-1-gpu")
def _make_mock_req(
req_pool_idx: int,
kv_committed_len: int,
kv_allocated_len: int,
prefix_indices_len: int = 0,
):
"""Create a mock Req with the KV cache state needed for testing."""
req = MagicMock()
req.req_pool_idx = req_pool_idx
req.kv_committed_len = kv_committed_len
req.kv_allocated_len = kv_allocated_len
req.kv_committed_freed = False
req.kv_overallocated_freed = False
req.prefix_indices = list(range(prefix_indices_len))
def pop_committed():
assert not req.kv_committed_freed
req.kv_committed_freed = True
return req.kv_committed_len
def pop_overallocated():
assert not req.kv_overallocated_freed
req.kv_overallocated_freed = True
return req.kv_committed_len, req.kv_allocated_len
req.pop_committed_kv_cache = pop_committed
req.pop_overallocated_kv_cache = pop_overallocated
return req
def _make_manager(pool_size: int, page_size: int = 1):
    """Create a DecodeKVCacheOffloadManager with mock pools for testing.

    Args:
        pool_size: Number of token slots in the single ``req_to_token`` row.
            The row is ``arange(pool_size)``, so the slot stored at position
            ``i`` is ``i`` itself.
        page_size: Value assigned to ``manager.page_size``.

    Returns:
        A ``(manager, freed_indices)`` tuple. ``freed_indices`` is a list that
        receives a cloned copy of the argument of every
        ``token_to_kv_pool_allocator.free(...)`` call, in call order.

    Note:
        ``req_to_token`` has exactly one row, so only ``req_pool_idx=0`` can
        be indexed through this manager.
    """
    # Build a real req_to_token tensor so indexing works
    req_to_token = torch.arange(pool_size, dtype=torch.int64).unsqueeze(0)
    req_to_token_pool = MagicMock()
    req_to_token_pool.req_to_token = req_to_token

    freed_indices = []

    def _record_free(idx):
        # Store a clone so the recorded tensor owns its data; presumably idx
        # is a view into req_to_token — verify against the manager.
        freed_indices.append(idx.clone())

    allocator = MagicMock()
    allocator.free = MagicMock(side_effect=_record_free)

    tree_cache = MagicMock()
    tree_cache.protected_size_ = 0

    # Bypass __init__ entirely and set attributes directly
    manager = object.__new__(DecodeKVCacheOffloadManager)
    manager.req_to_token_pool = req_to_token_pool
    manager.token_to_kv_pool_allocator = allocator
    manager.page_size = page_size
    manager.tree_cache = tree_cache
    return manager, freed_indices
class TestReleaseFinishedReq(unittest.TestCase):
    """Tests for _release_finished_req overallocation cleanup.

    Each test builds a manager with mocked pools, releases one mock request,
    and then inspects the tensors that were passed to the allocator's
    ``free`` method.
    """

    def _assert_freed_ranges(self, freed, expected_ranges):
        """Assert ``freed`` matches the expected ``(start, end)`` ranges.

        Args:
            freed: Tensors recorded from the allocator's ``free`` calls, in
                call order.
            expected_ranges: One ``(start, end)`` pair per expected call; the
                matching tensor must equal
                ``torch.arange(start, end, dtype=torch.int64)``.
        """
        self.assertEqual(
            len(freed),
            len(expected_ranges),
            msg=f"expected {len(expected_ranges)} free call(s), got {len(freed)}",
        )
        for call_idx, (actual, (start, end)) in enumerate(
            zip(freed, expected_ranges)
        ):
            expected = torch.arange(start, end, dtype=torch.int64)
            # A bare assertTrue(torch.equal(...)) only reports "False is not
            # true", so spell out which call and which slots mismatched.
            self.assertTrue(
                torch.equal(actual, expected),
                msg=(
                    f"free call {call_idx}: expected slots [{start}:{end}], "
                    f"got {actual.tolist()}"
                ),
            )

    def test_no_overallocation(self):
        """Without spec v2, kv_committed == kv_allocated; no extra free."""
        manager, freed = _make_manager(pool_size=32)
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=20,
            kv_allocated_len=20,  # no overallocation
        )
        prefill_offloaded_len = 8

        manager._release_finished_req(req, prefill_offloaded_len)

        # Only one free call: the committed range [8:20]
        self._assert_freed_ranges(freed, [(8, 20)])
        manager.req_to_token_pool.free.assert_called_once_with(req)

    def test_with_overallocation(self):
        """With spec v2, overallocated slots [committed:allocated] must be freed."""
        manager, freed = _make_manager(pool_size=32)
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=20,
            kv_allocated_len=28,  # 8 over-allocated slots
        )
        prefill_offloaded_len = 8

        manager._release_finished_req(req, prefill_offloaded_len)

        # Two free calls: committed [8:20] and overallocated [20:28]
        self._assert_freed_ranges(freed, [(8, 20), (20, 28)])
        manager.req_to_token_pool.free.assert_called_once_with(req)

    def test_overallocation_with_page_alignment(self):
        """With page_size > 1, start of overallocated range is ceil-aligned."""
        page_size = 4
        manager, freed = _make_manager(pool_size=32, page_size=page_size)
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=10,  # not page-aligned
            kv_allocated_len=28,
        )
        prefill_offloaded_len = 4

        manager._release_finished_req(req, prefill_offloaded_len)

        # Committed range [4:10]
        # Overallocated: start_p = ceil_align(10, 4) = 12, end_p = 28 => [12:28]
        self._assert_freed_ranges(freed, [(4, 10), (12, 28)])

    def test_overallocation_page_aligned_noop(self):
        """When ceil_align(committed, page_size) >= allocated, no overalloc free."""
        page_size = 4
        manager, freed = _make_manager(pool_size=32, page_size=page_size)
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=10,  # ceil_align(10, 4) = 12
            kv_allocated_len=12,  # same as aligned start
        )
        prefill_offloaded_len = 4

        manager._release_finished_req(req, prefill_offloaded_len)

        # Only committed [4:10], no overalloc because start_p == end_p
        self._assert_freed_ranges(freed, [(4, 10)])

    def test_prefix_indices_decremented(self):
        """protected_size_ is decremented by len(req.prefix_indices)."""
        manager, _ = _make_manager(pool_size=32)
        manager.tree_cache.protected_size_ = 10
        req = _make_mock_req(
            req_pool_idx=0,
            kv_committed_len=20,
            kv_allocated_len=20,
            prefix_indices_len=5,
        )

        manager._release_finished_req(req, prefill_offloaded_len=0)

        self.assertEqual(manager.tree_cache.protected_size_, 5)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()