Hanrui / sglang /test /registered /lora /test_lora_overlap_loading.py
Lekr0's picture
Add files using upload-large-folder tool
61ba51e verified
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest
from typing import cast
from unittest.mock import MagicMock, patch
from torch.cuda import Event as CudaEvent
from torch.cuda import Stream as CudaStream
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_overlap_loader import LoRAOverlapLoader, LoRAOverlapLoadStatus
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.lora_utils import (
CI_MULTI_LORA_MODELS,
run_lora_batch_splitting_equivalence_test,
)
from sglang.test.test_utils import CustomTestCase
register_cuda_ci(est_time=75, suite="stage-b-test-large-1-gpu")
register_amd_ci(est_time=75, suite="stage-b-test-small-1-gpu-amd")
class TestLoRAOverlapLoading(CustomTestCase):
def test_ci_lora_models_batch_splitting(self):
run_lora_batch_splitting_equivalence_test(
CI_MULTI_LORA_MODELS, enable_lora_overlap_loading=True
)
class TestLoRAOverlapLoaderUnitTests(CustomTestCase):
mock_lora_manager: MagicMock
mock_stream: MagicMock
mock_stream_context: MagicMock
mock_device_module: MagicMock
mock_torch: MagicMock
def setUp(self):
self.torch_patcher = patch("sglang.srt.lora.lora_overlap_loader.torch")
self.mock_torch = self.torch_patcher.start()
self.mock_device_module = MagicMock()
self.mock_stream = MagicMock(spec=CudaStream)
self.mock_stream_context = MagicMock()
self.mock_event = MagicMock(spec=CudaEvent)
self.mock_device_module.Stream.return_value = self.mock_stream
self.mock_device_module.stream.return_value = self.mock_stream_context
self.mock_device_module.Event.return_value = self.mock_event
self.mock_torch.get_device_module.return_value = self.mock_device_module
self.mock_torch.cuda.current_stream.return_value = MagicMock(spec=CudaStream)
self.mock_lora_manager = MagicMock(spec=LoRAManager)
self.mock_lora_manager.device = "cuda:0"
self.mock_lora_manager.validate_lora_batch.return_value = True
def tearDown(self):
self.torch_patcher.stop()
def _create_loader(self) -> LoRAOverlapLoader:
return LoRAOverlapLoader(cast(LoRAManager, self.mock_lora_manager))
def _create_mock_event(self, query_return: bool = False) -> MagicMock:
event = MagicMock(spec=CudaEvent)
event.query.return_value = query_return
return event
def test_full_lifecycle_single_lora_load(self):
loader = self._create_loader()
# Initially not loaded
status = loader._check_overlap_load_status("lora_A")
self.assertEqual(status, LoRAOverlapLoadStatus.NOT_LOADED)
# First call starts async load, returns False
result = loader.try_overlap_load_lora("lora_A", running_loras=set())
self.assertFalse(result)
self.assertIn("lora_A", loader.lora_to_overlap_load_event)
self.mock_lora_manager.fetch_new_loras.assert_called_once_with(
{"lora_A"}, set()
)
# Simulate load still in progress - returns False, event persists
loader.lora_to_overlap_load_event["lora_A"].query.return_value = False
result = loader.try_overlap_load_lora("lora_A", running_loras=set())
self.assertFalse(result)
self.assertEqual(
loader._check_overlap_load_status("lora_A"), LoRAOverlapLoadStatus.LOADING
)
# Simulate load complete - returns True, event removed
loader.lora_to_overlap_load_event["lora_A"].query.return_value = True
result = loader.try_overlap_load_lora("lora_A", running_loras=set())
self.assertTrue(result)
self.assertNotIn("lora_A", loader.lora_to_overlap_load_event)
def test_capacity_constraints_block_new_loads(self):
loader = self._create_loader()
events = [self._create_mock_event() for _ in range(4)]
self.mock_device_module.Event.side_effect = events
# Load 3 loras successfully
for i in range(3):
self.assertTrue(
loader._try_start_overlap_load(f"lora_{i}", running_loras=set())
)
self.assertEqual(len(loader.lora_to_overlap_load_event), 3)
# Capacity full - new load blocked
self.mock_lora_manager.validate_lora_batch.return_value = False
self.mock_lora_manager.fetch_new_loras.reset_mock()
result = loader.try_overlap_load_lora("lora_3", running_loras=set())
self.assertFalse(result)
self.mock_lora_manager.fetch_new_loras.assert_not_called()
self.assertNotIn("lora_3", loader.lora_to_overlap_load_event)
# First lora completes, freeing capacity
loader.lora_to_overlap_load_event["lora_0"].query.return_value = True
self.assertEqual(
loader._check_overlap_load_status("lora_0"), LoRAOverlapLoadStatus.LOADED
)
# Now new load succeeds
self.mock_lora_manager.validate_lora_batch.return_value = True
self.assertTrue(loader._try_start_overlap_load("lora_3", running_loras=set()))
def test_validation_includes_pending_and_running_loras(self):
loader = self._create_loader()
events = [self._create_mock_event() for _ in range(5)]
self.mock_device_module.Event.side_effect = events
# Start pending loads
loader._try_start_overlap_load("pending_1", running_loras=set())
loader._try_start_overlap_load("pending_2", running_loras=set())
# Load new lora with running_loras
self.mock_lora_manager.validate_lora_batch.reset_mock()
running = {"running_1", "running_2"}
loader.try_overlap_load_lora("new_lora", running_loras=running)
# Validation should include: pending + running + new
call_args = self.mock_lora_manager.validate_lora_batch.call_args[0][0]
expected = {"pending_1", "pending_2", "running_1", "running_2", "new_lora"}
self.assertEqual(call_args, expected)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")