| """ |
| Test forward_split_prefill functionality. |
| |
| Usage: |
| python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill |
| or |
| python3 test_forward_split_prefill.py |
| """ |
|
|
| import unittest |
|
|
| import numpy as np |
| import torch |
|
|
| from sglang.bench_one_batch import TreeCacheNamespace |
| from sglang.srt.configs.model_config import ModelConfig |
| from sglang.srt.managers.schedule_batch import Req, ScheduleBatch |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch |
| from sglang.srt.model_executor.model_runner import ModelRunner |
| from sglang.srt.sampling.sampling_params import SamplingParams |
| from sglang.srt.server_args import PortArgs, ServerArgs |
| from sglang.srt.speculative.spec_info import SpeculativeAlgorithm |
| from sglang.srt.utils import get_device |
| from sglang.srt.utils.hf_transformers_utils import get_tokenizer |
| from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase |
|
|
|
|
class TestForwardSplitPrefill(CustomTestCase):
    """Test cases for forward_split_prefill functionality.

    One ``ModelRunner`` is created in ``setUpClass`` and shared by every test.
    Each test builds its own batches through ``prepare_test_batch`` and then
    drives ``ModelRunner.forward_split_prefill`` over the model's layers.
    """

    @classmethod
    def setUpClass(cls):
        """Build the server args, model config, model runner and tokenizer once."""
        cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.tp_size = 1
        cls.device = get_device()

        # The tokenizer is loaded from the same path as the model; CUDA graphs
        # and hybrid SWA memory are switched off for these tests.
        cls.server_args = ServerArgs(
            model_path=cls.model_path,
            tokenizer_path=cls.model_path,
            host="127.0.0.1",
            disable_cuda_graph=True,
            disable_hybrid_swa_memory=True,
            port=30000,
            tp_size=cls.tp_size,
            mem_fraction_static=0.8,
            trust_remote_code=True,
        )

        cls.port_args = PortArgs.init_new(cls.server_args)

        # Single-rank setup: tensor-parallel, pipeline-parallel and MoE
        # expert-parallel sizes are all 1, and the runner is placed on GPU 0.
        cls.model_config = ModelConfig.from_server_args(cls.server_args)
        cls.model_runner = ModelRunner(
            model_config=cls.model_config,
            mem_fraction_static=cls.server_args.mem_fraction_static,
            gpu_id=0,
            tp_rank=0,
            tp_size=cls.tp_size,
            pp_rank=0,
            pp_size=1,
            nccl_port=cls.port_args.nccl_port,
            server_args=cls.server_args,
            moe_ep_rank=0,
            moe_ep_size=1,
        )

        cls.tokenizer = get_tokenizer(
            cls.server_args.tokenizer_path,
            tokenizer_mode=cls.server_args.tokenizer_mode,
            trust_remote_code=cls.server_args.trust_remote_code,
        )

        print(
            f"Test with model: {cls.model_path}, num_hidden_layers: {cls.model_config.num_hidden_layers}"
        )

    def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True):
        """Prepare a test batch for split prefill testing.

        Args:
            batch_size: Number of requests placed in the batch.
            input_len: Number of random input tokens generated per request.
            is_split_prefill: When true the batch is prepared with
                ``prepare_for_split_prefill``; otherwise with
                ``prepare_for_extend``.

        Returns:
            The ``ForwardBatch`` created by ``ForwardBatch.init_new`` from the
            scheduled batch and the shared model runner.
        """
        # Random token ids in [10, 1000). The generator is unseeded, so two
        # calls produce different ids; tests that compare outputs overwrite
        # ``input_ids``/``positions`` on the returned batch.
        input_ids = np.random.randint(10, 1000, (batch_size, input_len), dtype=np.int32)

        sampling_params = SamplingParams(
            temperature=0.0,
            max_new_tokens=8,
        )

        reqs = []
        for i in range(batch_size):
            req = Req(
                rid=i,
                origin_input_text="",
                origin_input_ids=list(input_ids[i]),
                sampling_params=sampling_params,
            )
            req.fill_ids = req.origin_input_ids
            req.logprob_start_len = -1
            # Tokens to extend = all fill ids not already covered by a prefix.
            req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices))
            reqs.append(req)

        # Tree-cache stand-in taken from the benchmark script; it only carries
        # the page size, device and KV-pool allocator.
        dummy_tree_cache = TreeCacheNamespace(
            page_size=1,
            device=self.model_runner.device,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
        )

        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            tree_cache=dummy_tree_cache,
            model_config=self.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
        )
        if is_split_prefill:
            batch.prepare_for_split_prefill()
        else:
            batch.prepare_for_extend()

        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

        return forward_batch

    def _run_split_prefill(self, forward_batch, chunk_size):
        """Run split prefill over every layer and return the final output.

        The split is started the same way the step-by-step tests in this class
        start it: ``split_index`` is reset to 0 and the attention backend is
        re-initialised on the first chunk only.

        Args:
            forward_batch: Batch prepared for split prefill.
            chunk_size: Number of layers to advance per call (``forward_count``).

        Returns:
            The last non-``None`` value returned by ``forward_split_prefill``,
            or ``None`` if no call returned a value.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            # A non-positive step would never advance split_index.
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        num_layers = self.model_config.num_hidden_layers
        forward_batch.split_index = 0

        final_result = None
        is_first_chunk = True
        while forward_batch.split_index < num_layers:
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=is_first_chunk,
                forward_count=chunk_size,
            )
            is_first_chunk = False
            if result is not None:
                final_result = result

        return final_result

    def test_split_prefill_functionality(self):
        """Test that split prefill can complete successfully."""
        print("\n=== Testing split prefill functionality ===")

        forward_batch = self.prepare_test_batch(batch_size=2, input_len=64)

        # Start from the first layer.
        forward_batch.split_index = 0

        # Aim for roughly four splits, but always advance at least one layer.
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 4)

        last_result = None
        split_count = 0

        while forward_batch.split_index < num_layers:
            print(
                f"Processing split {split_count}, split_index: {forward_batch.split_index}"
            )

            last_result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(split_count == 0),
                forward_count=chunk_size,
            )
            split_count += 1

            # After k splits the index must be k * chunk_size, capped at the
            # total number of layers.
            expected_next_index = min(split_count * chunk_size, num_layers)
            self.assertEqual(forward_batch.split_index, expected_next_index)

        self.assertIsNotNone(last_result, "Final split should return logits")
        print(f"Split prefill completed in {split_count} splits")

    def test_split_prefill_vs_normal_prefill(self):
        """Test that split prefill produces the same results as normal prefill."""
        print("\n=== Testing split prefill vs normal prefill consistency ===")

        forward_batch_normal = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=False
        )
        forward_batch_split = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=True
        )

        # Each batch was built from its own random ids; feed both paths the
        # same tokens and positions so their logits are comparable.
        forward_batch_split.input_ids = forward_batch_normal.input_ids.clone()
        forward_batch_split.positions = forward_batch_normal.positions.clone()

        print("Running normal extend (prefill)...")
        normal_result = self.model_runner.forward_extend(forward_batch_normal)

        print("Running split prefill...")
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 3)
        split_result = self._run_split_prefill(forward_batch_split, chunk_size)

        self.assertIsNotNone(normal_result, "Normal prefill should return result")
        self.assertIsNotNone(split_result, "Split prefill should return result")

        self.assertEqual(
            normal_result.next_token_logits.shape,
            split_result.next_token_logits.shape,
            "Logits shapes should match",
        )

        # Compare with a tolerance rather than requiring bitwise equality.
        torch.testing.assert_close(
            normal_result.next_token_logits,
            split_result.next_token_logits,
            rtol=1e-3,
            atol=1e-3,
            msg="Split prefill and normal prefill should produce similar logits",
        )

        print("✓ Split prefill and normal prefill produce consistent results")

    def test_split_prefill_different_chunk_sizes(self):
        """Test split prefill with different chunk sizes."""
        print("\n=== Testing split prefill with different chunk sizes ===")

        num_layers = self.model_config.num_hidden_layers
        candidate_sizes = [1, 2, max(1, num_layers // 2), num_layers]
        # Drop sizes larger than the model and de-duplicate while keeping
        # order (shallow models yield repeats such as [1, 2, 2, 4]).
        chunk_sizes = list(
            dict.fromkeys(size for size in candidate_sizes if size <= num_layers)
        )

        # One reference batch supplies the tokens/positions reused by all runs.
        base_batch = self.prepare_test_batch(batch_size=1, input_len=16)
        base_input_ids = base_batch.input_ids.clone()
        base_positions = base_batch.positions.clone()

        # (chunk_size, result) pairs, so failure messages always report the
        # chunk size that actually produced the mismatching result.
        results = []

        for chunk_size in chunk_sizes:
            print(f"Testing chunk size: {chunk_size}")

            forward_batch = self.prepare_test_batch(batch_size=1, input_len=16)
            forward_batch.input_ids = base_input_ids.clone()
            forward_batch.positions = base_positions.clone()

            split_result = self._run_split_prefill(forward_batch, chunk_size)

            self.assertIsNotNone(
                split_result,
                f"Split prefill should succeed with chunk_size={chunk_size}",
            )
            results.append((chunk_size, split_result))

        # Every later run is compared against the first one.
        if results:
            _, reference = results[0]
            for chunk_size, result in results[1:]:
                torch.testing.assert_close(
                    reference.next_token_logits,
                    result.next_token_logits,
                    rtol=1e-3,
                    atol=1e-3,
                    msg=f"Results with different chunk sizes should be identical (chunk_size {chunk_size})",
                )

        print("✓ All chunk sizes produce consistent results")

    def test_split_prefill_edge_cases(self):
        """Test edge cases for split prefill."""
        print("\n=== Testing split prefill edge cases ===")

        forward_batch = self.prepare_test_batch(batch_size=1, input_len=8)

        # Advance exactly one layer per call and check what each call returns.
        num_layers = self.model_config.num_hidden_layers
        for layer_idx in range(num_layers):
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(layer_idx == 0),
                forward_count=1,
            )

            if layer_idx == num_layers - 1:
                # Only the call that processes the last layer yields output.
                self.assertIsNotNone(result, "Last layer should return logits")
            else:
                # Intermediate calls are expected to return nothing.
                self.assertIsNone(result, f"Layer {layer_idx} should return None")

        print("✓ Single layer processing works correctly")
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers and runs the test cases defined in this module.
    unittest.main()
|
|