"""Unit tests for attaching forward hooks via ``register_forward_hooks``."""

import argparse
import json
import unittest

import torch
import torch.nn as nn

from sglang.srt.model_executor.hook_manager import register_forward_hooks
from sglang.srt.server_args import ServerArgs
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.test_utils import CustomTestCase

register_cuda_ci(est_time=6, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=10, suite="stage-b-test-small-1-gpu-amd")

# One entry is appended per hook invocation (see dummy_hook_factory); the
# list is cleared in TestAttachHooks.setUp so each test starts empty.
HOOK_CALLS = []

# "module:attr" import path for the factory below. It is derived from
# ``__name__`` instead of hard-coding "test_model_hooks" so the factory
# resolves to *this* module object, and therefore to this HOOK_CALLS list.
# With a hard-coded name, running the file as a script (``__name__ ==
# "__main__"``) would make the import system load a second copy of the file
# under "test_model_hooks". Its hooks would append to that copy's HOOK_CALLS,
# which the assertions never read.
# NOTE(review): assumes resolve_callable imports the module part through the
# normal import system (importlib), which returns the already-loaded module
# from sys.modules — confirm against hook_manager.
DUMMY_HOOK_FACTORY = f"{__name__}:dummy_hook_factory"


def dummy_hook_factory(config):
    """Build a forward hook that records each call into ``HOOK_CALLS``.

    Args:
        config: Mapping from the hook spec; only the optional ``"tag"`` key is
            read (defaults to ``"default"``).

    Returns:
        A ``hook(module, inputs, output)`` callable. Each call appends a dict
        with the module's class name, the tag, and ``tuple(output.shape)`` to
        ``HOOK_CALLS``, then returns ``output`` unchanged.
    """
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                "shape": tuple(output.shape),
            }
        )
        return output

    return hook


class TinyModel(nn.Module):
    """Two-level toy model: ``outer`` = Linear(4, 4) -> ReLU -> ``inner``.

    ``inner`` (Linear(4, 2) -> ReLU) is registered twice: as the ``inner``
    attribute and as ``outer[2]``. Because ``inner`` is assigned first,
    ``named_modules()`` (which skips duplicates by default) reports its
    children as ``inner.0`` / ``inner.1`` rather than ``outer.2.*``.
    Presumably register_forward_hooks matches patterns against those names;
    the expected call counts in the tests below rely on this.
    """

    def __init__(self):
        super().__init__()
        self.inner = nn.Sequential(
            nn.Linear(4, 2),
            nn.ReLU(),
        )
        self.outer = nn.Sequential(
            nn.Linear(4, 4),
            nn.ReLU(),
            self.inner,
        )

    def forward(self, x):
        return self.outer(x)


class TestAttachHooks(CustomTestCase):
    """Tests for register_forward_hooks / resolve_callable integration."""

    def setUp(self):
        HOOK_CALLS.clear()

    def test_hook_is_attached(self):
        """Hook from a factory string is registered and fired."""
        hook_specs = [
            {
                "target_modules": ["outer.0", "outer.1"],
                "hook_factory": DUMMY_HOOK_FACTORY,
                "config": {"tag": "forward-ok"},
            },
            {
                "target_modules": ["inner.*"],
                "hook_factory": DUMMY_HOOK_FACTORY,
                "config": {"tag": "forward-ok"},
            },
        ]
        model = TinyModel()
        register_forward_hooks(model, hook_specs)

        model(torch.randn(3, 4))

        # Hooks on outer.0, outer.1, inner.0, inner.1, each fired once.
        self.assertEqual(
            len(HOOK_CALLS),
            4,
            "Forward hook was not called correct number of times",
        )
        tags = {call["tag"] for call in HOOK_CALLS}
        self.assertEqual(tags, {"forward-ok"})

    def test_no_matching_modules_does_not_crash(self):
        """Hook spec with no matching modules should not crash."""
        model = TinyModel()
        hook_specs = [
            {
                "name": "no_match",
                "target_modules": ["does_not_exist.*"],
                "hook_factory": DUMMY_HOOK_FACTORY,
                "config": {"tag": "unused"},
            }
        ]

        register_forward_hooks(model, hook_specs)
        model(torch.randn(3, 4))

        # No hooks should have fired.
        self.assertEqual(len(HOOK_CALLS), 0)

    def test_cli_hooks_reach_model(self):
        """
        Ensure that when hooks are provided via CLI, they are parsed into
        ServerArgs, passed to register_forward_hooks, and actually run
        during a forward pass.
        """
        parser = argparse.ArgumentParser()
        ServerArgs.add_cli_args(parser)

        hooks_spec = [
            {
                "name": "outer_and_inner_from_cli",
                "target_modules": ["outer.0", "outer.1", "inner.*"],
                "hook_factory": DUMMY_HOOK_FACTORY,
                "config": {"tag": "cli-hook"},
            }
        ]
        cli_args = [
            "--model-path",
            "Qwen/Qwen2-7B-Instruct",  # Dummy value; not used in this test
            "--forward-hooks",
            json.dumps(hooks_spec),
        ]

        args = parser.parse_args(cli_args)
        server_args = ServerArgs.from_cli_args(args)

        # The JSON from the CLI must round-trip back to the original specs.
        self.assertEqual(server_args.forward_hooks, hooks_spec)

        model = TinyModel()
        register_forward_hooks(model, server_args.forward_hooks)

        model(torch.randn(3, 4))

        # We expect hooks on outer.0, outer.1, inner.0, inner.1 => 4 calls
        self.assertEqual(
            len(HOOK_CALLS),
            4,
            "CLI-configured hooks did not fire expected number of times",
        )
        tags = {call["tag"] for call in HOOK_CALLS}
        self.assertEqual(tags, {"cli-hook"})


if __name__ == "__main__":
    unittest.main()