| import argparse |
| import json |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from sglang.srt.model_executor.hook_manager import register_forward_hooks |
| from sglang.srt.server_args import ServerArgs |
| from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci |
| from sglang.test.test_utils import CustomTestCase |
|
|
| register_cuda_ci(est_time=6, suite="stage-b-test-small-1-gpu") |
| register_amd_ci(est_time=10, suite="stage-b-test-small-1-gpu-amd") |
|
|
| HOOK_CALLS = [] |
|
|
|
|
| def dummy_hook_factory(config): |
| """Factory that returns a forward hook capturing a tag from config.""" |
| tag = config.get("tag", "default") |
|
|
| def hook(module, inputs, output): |
| HOOK_CALLS.append( |
| { |
| "module_type": type(module).__name__, |
| "tag": tag, |
| "shape": tuple(output.shape), |
| } |
| ) |
| return output |
|
|
| return hook |
|
|
|
|
| class TinyModel(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.inner = nn.Sequential( |
| nn.Linear(4, 2), |
| nn.ReLU(), |
| ) |
| self.outer = nn.Sequential( |
| nn.Linear(4, 4), |
| nn.ReLU(), |
| self.inner, |
| ) |
|
|
| def forward(self, x): |
| return self.outer(x) |
|
|
|
|
| class TestAttachHooks(CustomTestCase): |
| """Tests for register_forward_hooks / resolve_callable integration.""" |
|
|
| def setUp(self): |
| HOOK_CALLS.clear() |
|
|
| def test_hook_is_attached(self): |
| """Hook from a factory string is registered and fired.""" |
| hook_specs = [ |
| { |
| "target_modules": ["outer.0", "outer.1"], |
| "hook_factory": "test_model_hooks:dummy_hook_factory", |
| "config": {"tag": "forward-ok"}, |
| }, |
| { |
| "target_modules": ["inner.*"], |
| "hook_factory": "test_model_hooks:dummy_hook_factory", |
| "config": {"tag": "forward-ok"}, |
| }, |
| ] |
|
|
| model = TinyModel() |
| register_forward_hooks(model, hook_specs) |
|
|
| x = torch.randn(3, 4) |
| _ = model(x) |
|
|
| self.assertEqual( |
| len(HOOK_CALLS), |
| 4, |
| "Forward hook was not called correct number of times", |
| ) |
| tags = {call["tag"] for call in HOOK_CALLS} |
| self.assertIn("forward-ok", tags) |
|
|
| def test_no_matching_modules_does_not_crash(self): |
| """Hook spec with no matching modules should not crash.""" |
| model = TinyModel() |
| hook_specs = [ |
| { |
| "name": "no_match", |
| "target_modules": ["does_not_exist.*"], |
| "hook_factory": "test_model_hooks:dummy_hook_factory", |
| "config": {"tag": "unused"}, |
| } |
| ] |
|
|
| register_forward_hooks(model, hook_specs) |
|
|
| x = torch.randn(3, 4) |
| _ = model(x) |
|
|
| |
| self.assertEqual(len(HOOK_CALLS), 0) |
|
|
| def test_cli_hooks_reach_model(self): |
| """ |
| Ensure that when hooks are provided via CLI, they are parsed into |
| ServerArgs, passed to register_forward_hooks, and actually |
| run during a forward pass. |
| """ |
| parser = argparse.ArgumentParser() |
| ServerArgs.add_cli_args(parser) |
|
|
| hooks_spec = [ |
| { |
| "name": "outer_and_inner_from_cli", |
| "target_modules": ["outer.0", "outer.1", "inner.*"], |
| "hook_factory": "test_model_hooks:dummy_hook_factory", |
| "config": {"tag": "cli-hook"}, |
| } |
| ] |
|
|
| cli_args = [ |
| "--model-path", |
| "Qwen/Qwen2-7B-Instruct", |
| "--forward-hooks", |
| json.dumps(hooks_spec), |
| ] |
|
|
| args = parser.parse_args(cli_args) |
| server_args = ServerArgs.from_cli_args(args) |
|
|
| self.assertEqual(server_args.forward_hooks, hooks_spec) |
|
|
| model = TinyModel() |
| register_forward_hooks(model, server_args.forward_hooks) |
|
|
| x = torch.randn(3, 4) |
| _ = model(x) |
|
|
| |
| self.assertEqual( |
| len(HOOK_CALLS), |
| 4, |
| "CLI-configured hooks did not fire expected number of times", |
| ) |
|
|
| tags = {call["tag"] for call in HOOK_CALLS} |
| self.assertEqual(tags, {"cli-hook"}) |
|
|
|
|
| if __name__ == "__main__": |
| pass |
| |
|
|