File size: 612 Bytes
a402b9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | import re
import unittest
import torch
# Short alias for the `sgl_kernel` operator namespace exposed through
# `torch.ops`; the test below calls `init_cpu_threads_env` on it.
# NOTE(review): presumably the sgl_kernel extension must already be loaded
# for these ops to be registered — confirm where that happens.
kernel = torch.ops.sgl_kernel
from sglang.test.test_utils import CustomTestCase
class TestGemm(CustomTestCase):
    """Checks the core bindings reported by ``kernel.init_cpu_threads_env``."""

    def test_binding(self):
        """Bind to six consecutive cores starting at core 1 and verify the report.

        The comma-separated core list is handed to ``init_cpu_threads_env``;
        the text it returns is scanned for ``OMP tid: <n>, core <m>`` entries
        and the captured core ids must equal the requested ones, in order.
        """
        first_core = 1
        core_count = 6
        # Core ids are kept as strings because the regex captures strings.
        wanted = [str(core) for core in range(first_core, first_core + core_count)]

        report = kernel.init_cpu_threads_env(",".join(wanted))
        bound = re.findall(r"OMP tid: \d+, core (\d+)", report)

        # One reported entry per requested core, in the requested order.
        self.assertEqual(len(bound), core_count)
        self.assertEqual(bound, wanted)
if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    unittest.main()
|