# Source: Hanrui/sglang — test/registered/debug_utils/test_schedule_simulator.py
# Uploaded by Lekr0 ("Add files using upload-large-folder tool"), commit 61ba51e (verified).
import json
import os
import subprocess
import sys
import tempfile
import unittest
from collections import Counter

from sglang.srt.debug_utils.schedule_simulator import (
    AttentionComputeBalancednessRecorder,
    BatchSizeBalancednessRecorder,
    FIFOScheduler,
    GPUState,
    RandomRouter,
    RoundRobinRouter,
    SimRequest,
    SimulationResult,
    Simulator,
    StepRecord,
    StickyRouter,
    create_arg_parser,
    generate_gsp_requests,
    generate_random_requests,
    load_from_request_logger,
    main,
)
from sglang.test.ci.ci_register import register_cpu_ci
from sglang.test.test_utils import CustomTestCase
# Register this module with the CPU CI suite registry (nightly run in the
# "default" suite). est_time is presumably an estimate in seconds — confirm
# against sglang.test.ci.ci_register.
register_cpu_ci(
    suite="default",
    nightly=True,
    est_time=120,
)

# ==================== Non-E2E Tests ====================
class TestSimRequest(CustomTestCase):
    """Per-request bookkeeping on ``SimRequest``: seq_len and completion."""

    def test_basic(self):
        fresh = SimRequest(request_id="r1", input_len=100, output_len=50)
        # Nothing decoded yet, so the sequence is just the prompt.
        self.assertEqual(fresh.decoded_tokens, 0)
        self.assertEqual(fresh.seq_len(), 100)
        self.assertFalse(fresh.is_finished())

    def test_seq_len_with_decoded(self):
        partially_decoded = SimRequest(
            request_id="r1",
            input_len=100,
            output_len=50,
            decoded_tokens=10,
        )
        # seq_len counts the prompt plus every token decoded so far.
        self.assertEqual(partially_decoded.seq_len(), 110)

    def test_is_finished(self):
        fully_decoded = SimRequest(
            request_id="r1",
            input_len=100,
            output_len=50,
            decoded_tokens=50,
        )
        # decoded_tokens == output_len means the request is complete.
        self.assertTrue(fully_decoded.is_finished())
class TestGPUState(CustomTestCase):
    """Batch accounting on ``GPUState``, including shared-prefix token counting."""

    @staticmethod
    def _fresh_gpu():
        """Return an empty GPU with a budget large enough never to matter here."""
        return GPUState(gpu_id=0, max_total_tokens=10000)

    @staticmethod
    def _grouped_req(request_id, input_len, group_id, prefix_len, **extra):
        """Build a request in ``group_id`` whose first ``prefix_len`` tokens are shared.

        ``output_len`` is fixed at 50; ``extra`` is forwarded to ``SimRequest``
        (e.g. ``decoded_tokens``).
        """
        return SimRequest(
            request_id=request_id,
            input_len=input_len,
            output_len=50,
            group_id=group_id,
            prefix_len=prefix_len,
            **extra,
        )

    def test_batch_size(self):
        gpu = self._fresh_gpu()
        self.assertEqual(gpu.batch_size(), 0)
        gpu.running_requests = [
            SimRequest(request_id=rid, input_len=in_len, output_len=out_len)
            for rid, in_len, out_len in (("r1", 100, 50), ("r2", 200, 100))
        ]
        self.assertEqual(gpu.batch_size(), 2)

    def test_total_seq_len(self):
        gpu = self._fresh_gpu()
        gpu.running_requests = [
            SimRequest(request_id="r1", input_len=100, output_len=50),
            SimRequest(
                request_id="r2",
                input_len=200,
                output_len=100,
                decoded_tokens=10,
            ),
        ]
        # Ungrouped requests each contribute their full seq_len.
        self.assertEqual(gpu.total_seq_len(), 100 + 210)

    def test_total_seq_len_shared_prefix(self):
        gpu = self._fresh_gpu()
        gpu.running_requests = [
            self._grouped_req("r1", 150, "g0", 100),
            self._grouped_req("r2", 150, "g0", 100),
        ]
        # The 100-token g0 prefix is counted once: 150 + (150 - 100).
        self.assertEqual(gpu.total_seq_len(), 150 + 50)

    def test_total_seq_len_shared_prefix_with_decoded(self):
        gpu = self._fresh_gpu()
        gpu.running_requests = [
            self._grouped_req("r1", 150, "g0", 100, decoded_tokens=10),
            self._grouped_req("r2", 150, "g0", 100, decoded_tokens=5),
        ]
        # Only the prefix is shared; decoded tokens stay per-request:
        # (150 + 10) + (150 + 5 - 100).
        self.assertEqual(gpu.total_seq_len(), 160 + 55)

    def test_total_seq_len_multiple_groups(self):
        gpu = self._fresh_gpu()
        gpu.running_requests = [
            self._grouped_req("r1", 150, "g0", 100),
            self._grouped_req("r2", 150, "g0", 100),
            self._grouped_req("r3", 200, "g1", 150),
            SimRequest(request_id="r4", input_len=80, output_len=20),
        ]
        # g0 shares its prefix, g1 has a single member, r4 has no group.
        self.assertEqual(gpu.total_seq_len(), 150 + 50 + 200 + 80)
class TestRouters(CustomTestCase):
    """Routing policies: round-robin, random, and group-sticky."""

    def test_round_robin(self):
        router = RoundRobinRouter(num_gpus=4)
        probe = SimRequest(request_id="r1", input_len=100, output_len=50)
        routed = []
        for _ in range(8):
            routed.append(router.route(probe))
        # Eight routes over four GPUs cycle through every GPU twice, in order.
        self.assertEqual(routed, [0, 1, 2, 3] * 2)

    def test_random_router(self):
        router = RandomRouter(num_gpus=4)
        probe = SimRequest(request_id="r1", input_len=100, output_len=50)
        routed = [router.route(probe) for _ in range(100)]
        # Choices are random, so only the valid GPU range is checked.
        self.assertTrue(all(0 <= gpu_id < 4 for gpu_id in routed))

    def test_sticky_router_same_group_same_gpu(self):
        router = StickyRouter(num_gpus=4)
        members = [
            SimRequest(request_id=f"r{i}", input_len=100, output_len=50, group_id="g0")
            for i in range(10)
        ]
        chosen_gpus = {router.route(member) for member in members}
        # Every member of g0 must land on the same GPU.
        self.assertEqual(len(chosen_gpus), 1)

    def test_sticky_router_no_group_fallback(self):
        router = StickyRouter(num_gpus=4)
        ungrouped = [
            SimRequest(request_id=f"r{i}", input_len=100, output_len=50)
            for i in range(100)
        ]
        routed = [router.route(req) for req in ungrouped]
        # Requests without a group_id still get a valid GPU.
        self.assertTrue(all(0 <= gpu_id < 4 for gpu_id in routed))

    def test_sticky_router_multiple_groups(self):
        router = StickyRouter(num_gpus=4)
        for gid in ("g0", "g1", "g2"):
            members = [
                SimRequest(
                    request_id=f"{gid}_r{i}",
                    input_len=100,
                    output_len=50,
                    group_id=gid,
                )
                for i in range(5)
            ]
            # The same router is reused, so each group must stay pinned to one GPU.
            self.assertEqual(len({router.route(m) for m in members}), 1)
class TestFIFOScheduler(CustomTestCase):
    """Admission and eviction behaviour of ``FIFOScheduler`` under a token budget."""

    @staticmethod
    def _uniform_requests(count):
        """Return ``count`` requests r0..r{count-1}, each with 100 prompt / 50 output tokens."""
        return [
            SimRequest(request_id=f"r{i}", input_len=100, output_len=50)
            for i in range(count)
        ]

    def test_runs_pending_requests(self):
        gpu = GPUState(gpu_id=0, max_total_tokens=10000)
        gpu.pending_requests = self._uniform_requests(3)
        FIFOScheduler().schedule(gpu)
        # Ample budget: everything queued gets admitted.
        self.assertEqual(len(gpu.running_requests), 3)
        self.assertEqual(len(gpu.pending_requests), 0)

    def test_respects_token_limit(self):
        gpu = GPUState(gpu_id=0, max_total_tokens=250)
        gpu.pending_requests = self._uniform_requests(5)
        FIFOScheduler().schedule(gpu)
        # 250 tokens hold two 100-token prompts but not a third.
        self.assertEqual(len(gpu.running_requests), 2)
        self.assertEqual(len(gpu.pending_requests), 3)

    def test_evicts_lifo_when_over_budget(self):
        gpu = GPUState(gpu_id=0, max_total_tokens=250)
        gpu.running_requests = self._uniform_requests(3)  # 300 tokens > 250 budget
        FIFOScheduler().schedule(gpu)
        # The last running request (r2) is the one pushed back to pending.
        self.assertEqual(len(gpu.running_requests), 2)
        self.assertEqual(len(gpu.pending_requests), 1)
        self.assertEqual(gpu.pending_requests[0].request_id, "r2")
class TestMetrics(CustomTestCase):
    """Balancedness recorders evaluated over a single simulated step."""

    @staticmethod
    def _two_gpus():
        """Return two idle GPUs with ample token budget."""
        return [GPUState(gpu_id=i, max_total_tokens=10000) for i in range(2)]

    def test_batch_size_balancedness(self):
        recorder = BatchSizeBalancednessRecorder()
        gpus = self._two_gpus()
        gpus[0].running_requests = [
            SimRequest(request_id="r1", input_len=100, output_len=50)
        ]
        gpus[1].running_requests = [
            SimRequest(request_id=rid, input_len=100, output_len=50)
            for rid in ("r2", "r3")
        ]
        recorder.on_step_end(0, gpus)
        # Batch sizes 1 vs 2; 0.75 is consistent with mean/max = 1.5 / 2.
        summary = recorder.get_summary()
        self.assertAlmostEqual(summary["batch_size_balancedness_mean"], 0.75)

    def test_attention_compute_balancedness(self):
        recorder = AttentionComputeBalancednessRecorder()
        gpus = self._two_gpus()
        gpus[0].running_requests = [
            SimRequest(request_id="r1", input_len=100, output_len=50)
        ]
        gpus[1].running_requests = [
            SimRequest(request_id="r2", input_len=200, output_len=50)
        ]
        recorder.on_step_end(0, gpus)
        # Sequence lengths 100 vs 200; 0.75 is consistent with mean/max = 150 / 200.
        summary = recorder.get_summary()
        self.assertAlmostEqual(summary["attention_compute_balancedness_mean"], 0.75)

    def test_empty_history(self):
        # With no recorded steps the summary reports 0.0.
        summary = BatchSizeBalancednessRecorder().get_summary()
        self.assertEqual(summary["batch_size_balancedness_mean"], 0.0)

    def test_all_zero_batch_size(self):
        recorder = BatchSizeBalancednessRecorder()
        # Two idle GPUs count as perfectly balanced.
        recorder.on_step_end(0, self._two_gpus())
        summary = recorder.get_summary()
        self.assertAlmostEqual(summary["batch_size_balancedness_mean"], 1.0)
class TestDataLoader(CustomTestCase):
    """Parsing request-logger JSON-lines files into ``SimRequest`` objects."""

    def _write_log(self, lines):
        """Write ``lines`` to a new temporary ``.log`` file and return its path.

        The file is closed before returning (so it can be reopened on any
        platform) and is deleted automatically when the test finishes; the
        original ``delete=False`` files were never removed.
        """
        with tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) as f:
            self.addCleanup(os.remove, f.name)
            f.writelines(lines)
        return f.name

    def test_load_from_request_logger(self):
        log_data = [
            {"event": "request.received", "rid": "r1", "obj": {"text": "hello"}},
            {
                "event": "request.finished",
                "rid": "r1",
                "out": {"meta_info": {"prompt_tokens": 100, "completion_tokens": 50}},
            },
            {
                "event": "request.finished",
                "rid": "r2",
                "out": {"meta_info": {"prompt_tokens": 200, "completion_tokens": 100}},
            },
        ]
        path = self._write_log(json.dumps(item) + "\n" for item in log_data)
        requests = load_from_request_logger(path)
        # r1 appears in both a received and a finished event but yields one request.
        self.assertEqual(len(requests), 2)
        self.assertEqual(requests[0].request_id, "r1")
        self.assertEqual(requests[0].input_len, 100)
        self.assertEqual(requests[1].input_len, 200)

    def test_empty_file(self):
        path = self._write_log([])
        self.assertEqual(len(load_from_request_logger(path)), 0)
class TestDataSynthesis(CustomTestCase):
    """Synthetic workload generators: uniform random and grouped system prompt (GSP)."""

    def test_generate_basic(self):
        requests = generate_random_requests(
            num_requests=10, input_len=100, output_len=50
        )
        self.assertEqual(len(requests), 10)
        for req in requests:
            self.assertEqual(req.input_len, 100)
            self.assertEqual(req.output_len, 50)

    def test_generate_with_range_ratio(self):
        requests = generate_random_requests(
            num_requests=100, input_len=100, output_len=50, range_ratio=0.5, seed=42
        )
        # Guard against the range loop below passing vacuously on an empty result.
        self.assertEqual(len(requests), 100)
        for req in requests:
            # range_ratio=0.5 lets input_len vary within [0.5 * 100, 100].
            self.assertGreaterEqual(req.input_len, 50)
            self.assertLessEqual(req.input_len, 100)

    def test_generate_with_seed(self):
        r1 = generate_random_requests(
            num_requests=10, input_len=100, output_len=50, range_ratio=0.5, seed=42
        )
        r2 = generate_random_requests(
            num_requests=10, input_len=100, output_len=50, range_ratio=0.5, seed=42
        )
        # Compare whole lists: zip() would silently ignore a length mismatch and
        # pass vacuously if both results were empty.
        self.assertEqual(len(r1), 10)
        self.assertEqual([r.input_len for r in r1], [r.input_len for r in r2])

    def test_generate_gsp_basic(self):
        requests = generate_gsp_requests(
            num_groups=4,
            prompts_per_group=3,
            system_prompt_len=100,
            question_len=50,
            output_len=25,
            seed=42,
        )
        self.assertEqual(len(requests), 12)
        for req in requests:
            self.assertIsNotNone(req.group_id)
            # input_len = shared system prompt (prefix_len) + question.
            self.assertEqual(req.prefix_len, 100)
            self.assertEqual(req.input_len, 150)
            self.assertEqual(req.output_len, 25)

    def test_generate_gsp_group_assignment(self):
        requests = generate_gsp_requests(
            num_groups=3,
            prompts_per_group=2,
            system_prompt_len=100,
            question_len=50,
            output_len=25,
            seed=42,
        )
        group_counts = Counter(req.group_id for req in requests)
        self.assertEqual(len(group_counts), 3)
        for count in group_counts.values():
            self.assertEqual(count, 2)

    def test_generate_gsp_with_range_ratio(self):
        requests = generate_gsp_requests(
            num_groups=4,
            prompts_per_group=5,
            system_prompt_len=100,
            question_len=50,
            output_len=25,
            range_ratio=0.5,
            seed=42,
        )
        # Guard against the range loop below passing vacuously.
        self.assertEqual(len(requests), 20)
        for req in requests:
            question_len = req.input_len - req.prefix_len
            self.assertGreaterEqual(req.prefix_len, 50)
            self.assertLessEqual(req.prefix_len, 100)
            self.assertGreaterEqual(question_len, 25)
            self.assertLessEqual(question_len, 50)

    def test_generate_gsp_shuffled(self):
        requests = generate_gsp_requests(
            num_groups=4,
            prompts_per_group=10,
            system_prompt_len=100,
            question_len=50,
            output_len=25,
            seed=42,
        )
        group_ids = [req.group_id for req in requests]
        # Emitting groups back-to-back would presumably yield a sorted id
        # sequence; the generator is expected to shuffle that order away.
        self.assertNotEqual(group_ids, sorted(group_ids))
class TestSimulator(CustomTestCase):
    """In-process tests of ``Simulator.run`` on small hand-built workloads."""

    def test_basic_run(self):
        requests = [
            SimRequest(request_id=f"r{i}", input_len=10, output_len=5)
            for i in range(10)
        ]
        sim = Simulator(
            num_gpus_per_engine=2,
            router=RoundRobinRouter(num_gpus=2),
            scheduler=FIFOScheduler(),
            recorders=[
                BatchSizeBalancednessRecorder(),
                AttentionComputeBalancednessRecorder(),
            ],
            max_total_tokens=100,
        )
        result = sim.run(requests)
        self.assertIsInstance(result, SimulationResult)
        # Both configured recorders should contribute to the summary.
        self.assertIn("batch_size_balancedness_mean", result.summary)
        self.assertIn("attention_compute_balancedness_mean", result.summary)
        self.assertGreater(len(result.step_records), 0)

    def test_all_requests_complete(self):
        requests = [
            SimRequest(request_id=f"r{i}", input_len=10, output_len=3) for i in range(4)
        ]
        sim = Simulator(
            num_gpus_per_engine=2,
            router=RoundRobinRouter(num_gpus=2),
            scheduler=FIFOScheduler(),
            max_total_tokens=10000,
        )
        sim.run(requests)
        # After the run every GPU has drained both queues.
        for gpu in sim.gpu_states:
            self.assertEqual(len(gpu.pending_requests), 0)
            self.assertEqual(len(gpu.running_requests), 0)

    def test_empty_requests(self):
        sim = Simulator(
            num_gpus_per_engine=2,
            router=RoundRobinRouter(num_gpus=2),
            scheduler=FIFOScheduler(),
        )
        result = sim.run([])
        self.assertEqual(result.summary, {})
        self.assertEqual(len(result.step_records), 0)

    def test_step_records(self):
        requests = [
            SimRequest(request_id=f"r{i}", input_len=10, output_len=3) for i in range(4)
        ]
        sim = Simulator(
            num_gpus_per_engine=2,
            router=RoundRobinRouter(num_gpus=2),
            scheduler=FIFOScheduler(),
            max_total_tokens=10000,
        )
        result = sim.run(requests)
        self.assertGreater(len(result.step_records), 0)
        for record in result.step_records:
            self.assertIsInstance(record, StepRecord)
            self.assertIn(record.gpu_id, [0, 1])
        # Each of the two GPUs logs exactly one record for step 0.
        step0_count = sum(1 for r in result.step_records if r.step == 0)
        self.assertEqual(step0_count, 2)

    def test_preemption_due_to_token_growth(self):
        requests = [
            SimRequest(request_id="r0", input_len=50, output_len=10),
            SimRequest(request_id="r1", input_len=50, output_len=10),
        ]
        sim = Simulator(
            num_gpus_per_engine=1,
            router=RoundRobinRouter(num_gpus=1),
            scheduler=FIFOScheduler(),
            max_total_tokens=110,
        )
        result = sim.run(requests)
        # Both 50-token prompts fit in 110 tokens at first; as they decode, the
        # combined length outgrows the budget and one request is pushed back.
        found_preemption = any(
            record.running_count == 1 and record.pending_count == 1
            for record in result.step_records
        )
        self.assertTrue(
            found_preemption, "Expected preemption to occur due to token growth"
        )
# ==================== E2E Tests ====================
class TestCLI(CustomTestCase):
    """End-to-end tests that run the simulator via ``python -m``."""

    # Upper bound (seconds) on a single CLI invocation, so a hung subprocess
    # fails the test instead of stalling the CI job indefinitely.
    CLI_TIMEOUT_S = 600

    def _run_cli(self, *args):
        """Run the simulator CLI with ``args`` and return the ``CompletedProcess``.

        stdout/stderr are captured as text. Raises ``subprocess.TimeoutExpired``
        if the run exceeds ``CLI_TIMEOUT_S``.
        """
        return subprocess.run(
            [sys.executable, "-m", "sglang.srt.debug_utils.schedule_simulator", *args],
            capture_output=True,
            text=True,
            timeout=self.CLI_TIMEOUT_S,
        )

    def _make_temp_file(self, suffix, content=""):
        """Create a closed temp file holding ``content`` and return its path.

        The file is deleted automatically when the test finishes.
        """
        with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as f:
            self.addCleanup(os.remove, f.name)
            f.write(content)
        return f.name

    def _write_request_log(self, log_data):
        """Write ``log_data`` as JSON lines to a temp ``.log`` file; return its path."""
        return self._make_temp_file(
            ".log", "".join(json.dumps(item) + "\n" for item in log_data)
        )

    def _assert_output_contains(self, output: str, expected_lines: str):
        """Assert each line of ``expected_lines`` (outer whitespace stripped) is in ``output``."""
        for line in expected_lines.strip().split("\n"):
            self.assertIn(line, output)

    def test_cli_basic(self):
        log_data = [
            {
                "event": "request.finished",
                "rid": "r1",
                "out": {"meta_info": {"prompt_tokens": 100, "completion_tokens": 50}},
            },
            {
                "event": "request.finished",
                "rid": "r2",
                "out": {"meta_info": {"prompt_tokens": 200, "completion_tokens": 100}},
            },
        ]
        input_file = self._write_request_log(log_data)
        output_file = self._make_temp_file(".json")
        result = self._run_cli(
            "--input", input_file, "--num-gpus-per-engine", "2", "--output", output_file
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn("Loaded 2 requests", result.stdout)
        with open(output_file) as f:
            self.assertIn("batch_size_balancedness_mean", json.load(f))

    def test_cli_random_router(self):
        log_data = [
            {
                "event": "request.finished",
                "rid": "r1",
                "out": {"meta_info": {"prompt_tokens": 100, "completion_tokens": 50}},
            }
        ]
        input_file = self._write_request_log(log_data)
        result = self._run_cli("--input", input_file, "--router", "random")
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn("router=random", result.stdout)

    def test_e2e_sticky_router_group_locality(self):
        result = self._run_cli(
            "--synth-gsp",
            "--synth-gsp-num-groups",
            "1",
            "--synth-gsp-prompts-per-group",
            "4",
            "--synth-gsp-system-prompt-len",
            "10",
            "--synth-gsp-question-len",
            "10",
            "--synth-gsp-output-len",
            "2",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "2",
            "--router",
            "sticky",
            "--max-total-tokens",
            "1000",
            "--log-level",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        # One group + sticky routing: all 4 requests on one GPU, the other idle.
        self.assertIn("R=4:", result.stdout)
        self.assertIn("R=0:-", result.stdout)

    def test_cli_synthetic(self):
        result = self._run_cli(
            "--synthetic",
            "--synth-random-num-requests",
            "100",
            "--synth-random-input-len",
            "512",
            "--synth-random-output-len",
            "128",
            "--synth-random-range-ratio",
            "0.5",
            "--num-gpus-per-engine",
            "4",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn("Generated 100 random requests", result.stdout)

    def test_cli_log_level(self):
        result = self._run_cli(
            "--synthetic",
            "--synth-random-num-requests",
            "10",
            "--synth-random-output-len",
            "5",
            "--num-gpus-per-engine",
            "2",
            "--log-level",
            "1",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn("step=", result.stdout)

    def test_e2e_simple_no_queuing(self):
        result = self._run_cli(
            "--synthetic",
            "--synth-random-num-requests",
            "4",
            "--synth-random-input-len",
            "10",
            "--synth-random-output-len",
            "2",
            "--synth-random-range-ratio",
            "1.0",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "2",
            "--max-total-tokens",
            "10000",
            "--log-level",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn(
            "step=0 | GPU0[R=2:syn0,syn2 Q=0:-] | GPU1[R=2:syn1,syn3 Q=0:-]",
            result.stdout,
        )
        self.assertIn(
            "step=1 | GPU0[R=0:- Q=0:-] | GPU1[R=0:- Q=0:-]", result.stdout
        )
        self.assertIn("batch_size_balancedness_mean: 1.0000", result.stdout)

    def test_e2e_queuing_due_to_token_limit(self):
        result = self._run_cli(
            "--synthetic",
            "--synth-random-num-requests",
            "4",
            "--synth-random-input-len",
            "100",
            "--synth-random-output-len",
            "3",
            "--synth-random-range-ratio",
            "1.0",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "1",
            "--max-total-tokens",
            "210",
            "--log-level",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        # The expected lines are kept at column 0: _assert_output_contains only
        # strips the block as a whole, so indenting them would break matching.
        self._assert_output_contains(
            result.stdout,
            """
step=0 | GPU0[R=2:syn0,syn1 Q=2:syn2,syn3]
step=1 | GPU0[R=2:syn0,syn1 Q=2:syn2,syn3]
step=2 | GPU0[R=0:- Q=2:syn2,syn3]
step=3 | GPU0[R=2:syn2,syn3 Q=0:-]
step=4 | GPU0[R=2:syn2,syn3 Q=0:-]
step=5 | GPU0[R=0:- Q=0:-]""",
        )

    def test_e2e_retraction_due_to_token_growth(self):
        result = self._run_cli(
            "--synthetic",
            "--synth-random-num-requests",
            "2",
            "--synth-random-input-len",
            "50",
            "--synth-random-output-len",
            "10",
            "--synth-random-range-ratio",
            "1.0",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "1",
            "--max-total-tokens",
            "110",
            "--log-level",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self._assert_output_contains(
            result.stdout,
            """
step=0 | GPU0[R=2:syn0,syn1 Q=0:-]
step=5 | GPU0[R=2:syn0,syn1 Q=0:-]
step=6 | GPU0[R=1:syn0 Q=1:syn1]
step=9 | GPU0[R=0:- Q=1:syn1]
step=10 | GPU0[R=1:syn1 Q=0:-]
step=13 | GPU0[R=0:- Q=0:-]""",
        )

    def test_cli_gsp_basic(self):
        result = self._run_cli(
            "--synth-gsp",
            "--synth-gsp-num-groups",
            "4",
            "--synth-gsp-prompts-per-group",
            "8",
            "--synth-gsp-system-prompt-len",
            "100",
            "--synth-gsp-question-len",
            "50",
            "--synth-gsp-output-len",
            "10",
            "--synth-seed",
            "42",
            "--num-gpus-per-engine",
            "2",
        )
        self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
        self.assertIn("Generated 32 GSP requests", result.stdout)
        self.assertIn("4 groups x 8 prompts", result.stdout)

    def test_e2e_gsp_shared_prefix_enables_batching(self):
        # Same total prompt length (60) in both cases; only the shared portion
        # differs. With an 80-token budget, both requests fit together only when
        # the long prefix is shared.
        for has_long_prefix in [True, False]:
            with self.subTest(has_long_prefix=has_long_prefix):
                prefix_len, question_len = (50, 10) if has_long_prefix else (10, 50)
                result = self._run_cli(
                    "--synth-gsp",
                    "--synth-gsp-num-groups",
                    "1",
                    "--synth-gsp-prompts-per-group",
                    "2",
                    "--synth-gsp-system-prompt-len",
                    str(prefix_len),
                    "--synth-gsp-question-len",
                    str(question_len),
                    "--synth-gsp-output-len",
                    "2",
                    "--synth-seed",
                    "42",
                    "--num-gpus-per-engine",
                    "1",
                    "--max-total-tokens",
                    "80",
                    "--log-level",
                    "2",
                )
                self.assertEqual(result.returncode, 0, f"CLI failed: {result.stderr}")
                if has_long_prefix:
                    self.assertIn("R=2:", result.stdout)
                else:
                    self.assertNotIn("R=2:", result.stdout)
class TestLargerScale(CustomTestCase):
    """Large synthetic workloads run in-process through ``main`` to pin aggregate metrics."""

    # Options shared by every large-scale run: 250 engines x 8 GPUs, seed 42,
    # at most 1500 steps, and the "exist_no_pending" stop criterion.
    _COMMON_OPTIONS = dict(
        synth_seed=42,
        num_gpus_per_engine=8,
        num_engines=250,
        stop_criteria="exist_no_pending",
        max_steps=1500,
    )

    @staticmethod
    def _to_argv(mode_flag, **options):
        """Build a CLI argv: ``mode_flag`` then ``--key-with-dashes value`` per option."""
        argv = [mode_flag]
        for key, value in options.items():
            argv += ["--" + key.replace("_", "-"), str(value)]
        return argv

    def _run_main(self, *cli_args) -> SimulationResult:
        # Parse exactly as the CLI would, then invoke the entry point in-process.
        return main(create_arg_parser().parse_args(cli_args))

    def _assert_in_range(self, value, lo, hi, name):
        """Assert ``lo <= value <= hi``; the message names the violated bound."""
        self.assertGreaterEqual(value, lo, f"{name}={value} < {lo}")
        self.assertLessEqual(value, hi, f"{name}={value} > {hi}")

    def _assert_summary_ranges(self, summary, attn, bs, avg_bs):
        """Check the three headline metrics against ``(lo, hi)`` bounds, in order."""
        checks = (
            ("attention_compute_balancedness_mean", "attn", attn),
            ("batch_size_balancedness_mean", "bs", bs),
            ("avg_batch_size", "avg_bs", avg_bs),
        )
        for key, label, (lo, hi) in checks:
            self._assert_in_range(summary[key], lo, hi, label)

    def test_vanilla_workload_random_policy(self):
        argv = self._to_argv(
            "--synthetic",
            synth_random_num_requests=500000,
            synth_random_input_len=32000,
            synth_random_output_len=2000,
            router="random",
            max_total_tokens=2000000,
            **self._COMMON_OPTIONS,
        )
        result = self._run_main(*argv)
        self._assert_summary_ranges(
            result.summary, attn=(0.95, 1.0), bs=(0.90, 0.98), avg_bs=(127, 141)
        )

    def _run_gsp_workload(self, router: str) -> SimulationResult:
        argv = self._to_argv(
            "--synth-gsp",
            synth_gsp_num_groups=50000,
            synth_gsp_prompts_per_group=100,
            synth_gsp_system_prompt_len=31000,
            synth_gsp_question_len=1000,
            synth_gsp_output_len=8000,
            router=router,
            max_total_tokens=500000,
            **self._COMMON_OPTIONS,
        )
        return self._run_main(*argv)

    def test_gsp_workload_random_policy(self):
        summary = self._run_gsp_workload("random").summary
        self._assert_summary_ranges(
            summary, attn=(0.90, 0.97), bs=(0.90, 0.97), avg_bs=(14, 17)
        )

    def test_gsp_workload_sticky_policy(self):
        summary = self._run_gsp_workload("sticky").summary
        self._assert_summary_ranges(
            summary, attn=(0.64, 0.71), bs=(0.64, 0.71), avg_bs=(31, 36)
        )
if __name__ == "__main__":
    # Running this file directly discovers and runs the tests in this module.
    unittest.main(module=__name__)