File size: 2,949 Bytes
61ba51e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | import unittest
from sglang.srt.utils.gauge_histogram import BucketLabels
class TestBucketLabels(unittest.TestCase):
    """Pin the (lower, upper) label pairs BucketLabels yields for given bounds."""

    def _assert_labels(self, bounds, expected_pairs):
        """Build BucketLabels from ``bounds`` and compare its iteration order/content."""
        self.assertEqual(list(BucketLabels(bounds)), expected_pairs)

    def test_labels_basic(self):
        self._assert_labels(
            [10, 30, 60],
            [("0", "10"), ("10", "30"), ("30", "60"), ("60", "+Inf")],
        )

    def test_labels_single_bound(self):
        self._assert_labels([100], [("0", "100"), ("100", "+Inf")])

    def test_labels_many_bounds(self):
        self._assert_labels(
            [1, 2, 5, 10],
            [("0", "1"), ("1", "2"), ("2", "5"), ("5", "10"), ("10", "+Inf")],
        )

    def test_len(self):
        # Three bounds -> four buckets: one ending at each bound, plus the
        # trailing ("60", "+Inf") overflow bucket seen in test_labels_basic.
        self.assertEqual(len(BucketLabels([10, 30, 60])), 4)
class TestBucketLabelsCounts(unittest.TestCase):
    """Pin BucketLabels.compute_bucket_counts output for hand-checked inputs.

    Every case uses bounds [10, 30, 60], i.e. the buckets
    (0,10], (10,30], (30,60], (60,+Inf]. A value equal to an upper bound is
    expected to land in that bucket (see test_exact_boundary_values).
    """

    def setUp(self):
        # setUp runs before every test method, so each test gets its own
        # freshly constructed instance.
        self.buckets = BucketLabels([10, 30, 60])

    def _check(self, observations, expected):
        """Assert that counting ``observations`` yields exactly ``expected``."""
        self.assertEqual(self.buckets.compute_bucket_counts(observations), expected)

    def test_empty_observations(self):
        self._check([], [0, 0, 0, 0])

    def test_single_value_first_bucket(self):
        # 5 -> (0,10]
        self._check([5], [1, 0, 0, 0])

    def test_single_value_last_bucket(self):
        # 100 -> (60,+Inf]
        self._check([100], [0, 0, 0, 1])

    def test_exact_boundary_values(self):
        # Each value sits exactly on an upper bound:
        # 10 -> (0,10], 30 -> (10,30], 60 -> (30,60]
        self._check([10, 30, 60], [1, 1, 1, 0])

    def test_just_above_boundary(self):
        # One past each bound spills into the next bucket:
        # 11 -> (10,30], 31 -> (30,60], 61 -> (60,+Inf]
        self._check([11, 31, 61], [0, 1, 1, 1])

    def test_multiple_values_same_bucket(self):
        # All of 1..5 fall in (0,10]
        self._check([1, 2, 3, 4, 5], [5, 0, 0, 0])

    def test_all_overflow(self):
        # Everything exceeds the last bound -> (60,+Inf]
        self._check([100, 200, 300], [0, 0, 0, 3])

    def test_distribution(self):
        # 5, 10 -> (0,10]; 15 -> (10,30]; 40 -> (30,60]; 100 -> (60,+Inf]
        self._check([5, 10, 15, 40, 100], [2, 1, 1, 1])

    def test_float_values(self):
        # 9.9 -> (0,10], 10.1 -> (10,30], 30.5 -> (30,60]
        self._check([9.9, 10.1, 30.5], [1, 1, 1, 0])
if __name__ == "__main__":
    # Discover and run this module's TestCase classes when executed directly.
    unittest.main(module="__main__")
|