| from types import ModuleType |
|
|
| import pytest |
|
|
| from sglang.srt.debug_utils.source_patcher.code_patcher import ( |
| CodePatcher, |
| _resolve_target, |
| patch_function, |
| ) |
| from sglang.srt.debug_utils.source_patcher.types import EditSpec, PatchSpec |
| from sglang.test.ci.ci_register import register_cpu_ci |
|
|
# CI registration for this test file; presumably picked up by the CPU CI
# runner (meaning/units of ``est_time`` are defined by ``register_cpu_ci``).
register_cpu_ci(
    est_time=10,
    suite="default",
    nightly=True,
)


# Dotted module path used as the prefix of the target strings that the tests
# below hand to ``_resolve_target`` and ``PatchSpec``.
SAMPLE_MODULE_NAME = "_source_patcher_test_fixtures.sample_module"
|
|
|
|
class TestPatchFunction:
    """Tests that call ``patch_function`` directly on the sample module's callables.

    Each test applies one ``EditSpec`` to a target, checks the resulting
    behaviour, and always calls ``restore()`` on the returned state.
    """

    # Match text shared by the tests that rewrite ``SampleClass.greet``.
    _GREETING_LINE = 'greeting = f"hello {name}"'

    def test_basic_patch_changes_behavior(self, sample_module: ModuleType) -> None:
        """``greet`` returns the patched text while patched, the old text after."""
        sample_cls = sample_module.SampleClass
        instance = sample_cls()
        assert instance.greet("world") == "hello world"

        edit = EditSpec(
            match=self._GREETING_LINE,
            replacement='greeting = f"patched {name}"',
        )
        handle = patch_function(target=sample_cls.greet, edits=[edit])
        try:
            assert instance.greet("world") == "patched world"
        finally:
            handle.restore()

        assert instance.greet("world") == "hello world"

    def test_globals_preserved_after_patch(self, sample_module: ModuleType) -> None:
        """A patched ``uses_global`` still resolves ``GLOBAL_VAR``."""
        sample_cls = sample_module.SampleClass
        instance = sample_cls()
        assert instance.uses_global() == "value=global_value"

        edit = EditSpec(
            match='return f"value={GLOBAL_VAR}"',
            replacement='return f"patched_value={GLOBAL_VAR}"',
        )
        handle = patch_function(target=sample_cls.uses_global, edits=[edit])
        try:
            assert instance.uses_global() == "patched_value=global_value"
        finally:
            handle.restore()

    def test_function_identity_preserved(self, sample_module: ModuleType) -> None:
        """``id(SampleClass.greet)`` is the same before and during the patch."""
        sample_cls = sample_module.SampleClass
        original_id = id(sample_cls.greet)

        edit = EditSpec(
            match=self._GREETING_LINE,
            replacement='greeting = f"patched {name}"',
        )
        handle = patch_function(target=sample_cls.greet, edits=[edit])
        try:
            assert id(sample_cls.greet) == original_id
        finally:
            handle.restore()

    def test_patch_standalone_function(self, sample_module: ModuleType) -> None:
        """A module-level function can be patched and restored as well."""
        standalone = sample_module.standalone_function
        assert standalone(2, 3) == 5

        edit = EditSpec(match="return a + b", replacement="return a * b")
        handle = patch_function(target=standalone, edits=[edit])
        try:
            assert standalone(2, 3) == 6
        finally:
            handle.restore()

        assert standalone(2, 3) == 5

    def test_patched_code_can_reference_global_variable(
        self, sample_module: ModuleType
    ) -> None:
        """Replacement text may use the module-level ``GLOBAL_VAR``."""
        sample_cls = sample_module.SampleClass
        instance = sample_cls()

        edit = EditSpec(
            match=self._GREETING_LINE,
            replacement='greeting = f"{GLOBAL_VAR} {name}"',
        )
        handle = patch_function(target=sample_cls.greet, edits=[edit])
        try:
            assert instance.greet("world") == "global_value world"
        finally:
            handle.restore()

    def test_patched_code_can_call_another_class_method(
        self, sample_module: ModuleType
    ) -> None:
        """Replacement text may call ``HelperClass.format_value``."""
        sample_cls = sample_module.SampleClass
        instance = sample_cls()

        edit = EditSpec(
            match=self._GREETING_LINE,
            replacement="greeting = HelperClass.format_value(name)",
        )
        handle = patch_function(target=sample_cls.greet, edits=[edit])
        try:
            assert instance.greet("world") == "[world]"
        finally:
            handle.restore()

    def test_patched_code_uses_helper_via_existing_method(
        self, sample_module: ModuleType
    ) -> None:
        """``uses_helper`` already calls ``HelperClass``; that call keeps working."""
        sample_cls = sample_module.SampleClass
        instance = sample_cls()
        assert instance.uses_helper("test") == "[test]"

        edit = EditSpec(
            match="return HelperClass.format_value(value)",
            replacement='return HelperClass.format_value("patched_" + value)',
        )
        handle = patch_function(target=sample_cls.uses_helper, edits=[edit])
        try:
            assert instance.uses_helper("test") == "[patched_test]"
        finally:
            handle.restore()

        assert instance.uses_helper("test") == "[test]"
|
|
|
|
class TestResolveTarget:
    """Tests for ``_resolve_target`` using dotted paths under ``SAMPLE_MODULE_NAME``."""

    def test_resolve_class_method(self, sample_module: ModuleType) -> None:
        """A ``module.Class.method`` path yields that class attribute."""
        dotted_path = f"{SAMPLE_MODULE_NAME}.SampleClass.greet"
        resolved = _resolve_target(dotted_path)
        assert resolved is sample_module.SampleClass.greet

    def test_resolve_standalone_function(self, sample_module: ModuleType) -> None:
        """A ``module.function`` path yields the module-level function."""
        dotted_path = f"{SAMPLE_MODULE_NAME}.standalone_function"
        resolved = _resolve_target(dotted_path)
        assert resolved is sample_module.standalone_function

    def test_resolve_nonexistent_raises(self, sample_module: ModuleType) -> None:
        """An unknown class in the path raises ImportError or AttributeError."""
        missing_path = f"{SAMPLE_MODULE_NAME}.NonexistentClass.method"
        with pytest.raises((ImportError, AttributeError)):
            _resolve_target(missing_path)
|
|
|
|
class TestCodePatcher:
    """Tests for ``CodePatcher`` used as a context manager over ``PatchSpec`` lists."""

    @staticmethod
    def _greet_patch(replacement: str) -> PatchSpec:
        """Build the ``PatchSpec`` that swaps ``greet``'s greeting line.

        Args:
            replacement: Text substituted for the matched greeting assignment.
        """
        greeting_edit = EditSpec(
            match='greeting = f"hello {name}"',
            replacement=replacement,
        )
        return PatchSpec(
            target=f"{SAMPLE_MODULE_NAME}.SampleClass.greet",
            edits=[greeting_edit],
        )

    def test_context_manager_patches_and_restores(
        self, sample_module: ModuleType
    ) -> None:
        """The patch is active inside the ``with`` block and gone after it."""
        sample_cls = sample_module.SampleClass
        instance = sample_cls()
        assert instance.greet("world") == "hello world"

        patches = [self._greet_patch('greeting = f"ctx_patched {name}"')]

        with CodePatcher(patches=patches):
            assert instance.greet("world") == "ctx_patched world"

        assert instance.greet("world") == "hello world"

    def test_context_manager_multiple_patches(self, sample_module: ModuleType) -> None:
        """Two patches on different methods are both applied and both undone."""
        sample_cls = sample_module.SampleClass
        instance = sample_cls()

        compute_edit = EditSpec(
            match="result = x * 2 + 1",
            replacement="result = x * 100",
        )
        compute_patch = PatchSpec(
            target=f"{SAMPLE_MODULE_NAME}.SampleClass.compute",
            edits=[compute_edit],
        )
        patches = [self._greet_patch('greeting = f"p1 {name}"'), compute_patch]

        with CodePatcher(patches=patches):
            assert instance.greet("world") == "p1 world"
            assert instance.compute(5) == 500

        assert instance.greet("world") == "hello world"
        assert instance.compute(5) == 11

    def test_restores_on_exception(self, sample_module: ModuleType) -> None:
        """An exception raised inside the ``with`` block still undoes the patch."""
        sample_cls = sample_module.SampleClass
        instance = sample_cls()

        patches = [self._greet_patch('greeting = f"err_patched {name}"')]

        # One multi-item ``with`` is equivalent to nesting the two managers.
        with pytest.raises(RuntimeError), CodePatcher(patches=patches):
            assert instance.greet("world") == "err_patched world"
            raise RuntimeError("test error")

        assert instance.greet("world") == "hello world"
|
|