File size: 4,443 Bytes
61ba51e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | import argparse
import json
import torch
import torch.nn as nn
from sglang.srt.model_executor.hook_manager import register_forward_hooks
from sglang.srt.server_args import ServerArgs
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.test_utils import CustomTestCase
# Register this file with one CUDA and one AMD CI suite at import time.
# NOTE(review): est_time looks like an expected-runtime hint for the CI
# scheduler; confirm the unit against sglang.test.ci.ci_register.
register_cuda_ci(est_time=6, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=10, suite="stage-b-test-small-1-gpu-amd")
# Shared log of hook invocations; each entry is a dict appended by the hooks
# that dummy_hook_factory builds.
HOOK_CALLS = []


def dummy_hook_factory(config):
    """Build a forward hook that logs every invocation to ``HOOK_CALLS``.

    Args:
        config: Mapping read once, when the hook is built. Its ``"tag"``
            entry (``"default"`` when absent) is stamped on every record.

    Returns:
        A callable ``hook(module, inputs, output)`` that appends one record
        (module class name, tag, output shape as a tuple) to ``HOOK_CALLS``
        and returns ``output`` unchanged.
    """
    label = config.get("tag", "default")

    def hook(module, inputs, output):
        record = {
            "module_type": type(module).__name__,
            "tag": label,
            "shape": tuple(output.shape),
        }
        HOOK_CALLS.append(record)
        # Hand back the very same object so the observed output is not altered.
        return output

    return hook
class TinyModel(nn.Module):
    """Small two-stage network used as a target for forward-hook tests.

    ``inner`` is Linear(4, 2) followed by ReLU. ``outer`` is Linear(4, 4),
    ReLU, and then the very same ``inner`` module object, so ``inner`` is
    reachable both as an attribute of the model and as the last stage of
    ``outer``.
    """

    def __init__(self):
        super().__init__()
        # inner is built and registered first, exactly as before, so that
        # parameter initialisation order and submodule order stay the same.
        inner_stages = [nn.Linear(4, 2), nn.ReLU()]
        self.inner = nn.Sequential(*inner_stages)
        outer_stages = [nn.Linear(4, 4), nn.ReLU(), self.inner]
        self.outer = nn.Sequential(*outer_stages)

    def forward(self, x):
        """Run ``x`` through ``outer`` (which ends with ``inner``)."""
        result = self.outer(x)
        return result
class TestAttachHooks(CustomTestCase):
    """Check that hook specs end up as live forward hooks on a TinyModel.

    Every spec names its factory with the string
    "test_model_hooks:dummy_hook_factory". The hooks built by that factory
    record each call in the module-level HOOK_CALLS list, which is what the
    assertions here inspect.
    """

    # Factory reference string shared by every spec in this class.
    _FACTORY = "test_model_hooks:dummy_hook_factory"

    def setUp(self):
        # Empty the shared log in place so each test starts from zero calls.
        del HOOK_CALLS[:]

    def _forward_once(self, specs):
        """Register ``specs`` on a fresh TinyModel and run one random batch."""
        net = TinyModel()
        register_forward_hooks(net, specs)
        batch = torch.randn(3, 4)
        net(batch)

    def test_hook_is_attached(self):
        """Two specs that together target four submodules record four calls."""
        outer_spec = {
            "target_modules": ["outer.0", "outer.1"],
            "hook_factory": self._FACTORY,
            "config": {"tag": "forward-ok"},
        }
        inner_spec = {
            "target_modules": ["inner.*"],
            "hook_factory": self._FACTORY,
            "config": {"tag": "forward-ok"},
        }

        self._forward_once([outer_spec, inner_spec])

        fired = len(HOOK_CALLS)
        self.assertEqual(
            fired,
            4,
            "Forward hook was not called correct number of times",
        )
        seen_tags = {entry["tag"] for entry in HOOK_CALLS}
        self.assertIn("forward-ok", seen_tags)

    def test_no_matching_modules_does_not_crash(self):
        """A spec whose pattern matches nothing registers and runs cleanly."""
        unmatched_spec = {
            "name": "no_match",
            "target_modules": ["does_not_exist.*"],
            "hook_factory": self._FACTORY,
            "config": {"tag": "unused"},
        }

        self._forward_once([unmatched_spec])

        # Nothing was targeted, so nothing may have been recorded.
        self.assertEqual(len(HOOK_CALLS), 0)

    def test_cli_hooks_reach_model(self):
        """
        Hooks given as a JSON ``--forward-hooks`` argument are parsed into
        ServerArgs unchanged, accepted by register_forward_hooks, and fire
        during a forward pass.
        """
        spec_list = [
            {
                "name": "outer_and_inner_from_cli",
                "target_modules": ["outer.0", "outer.1", "inner.*"],
                "hook_factory": self._FACTORY,
                "config": {"tag": "cli-hook"},
            }
        ]

        cli_parser = argparse.ArgumentParser()
        ServerArgs.add_cli_args(cli_parser)
        argv = [
            "--model-path",
            "Qwen/Qwen2-7B-Instruct",  # Dummy value; not used in this test
            "--forward-hooks",
            json.dumps(spec_list),
        ]
        parsed = cli_parser.parse_args(argv)
        server_args = ServerArgs.from_cli_args(parsed)

        # The JSON round trip must hand back exactly the spec we sent in.
        self.assertEqual(server_args.forward_hooks, spec_list)

        self._forward_once(server_args.forward_hooks)

        # One call each for outer.0, outer.1, inner.0 and inner.1.
        fired = len(HOOK_CALLS)
        self.assertEqual(
            fired,
            4,
            "CLI-configured hooks did not fire expected number of times",
        )
        seen_tags = {entry["tag"] for entry in HOOK_CALLS}
        self.assertEqual(seen_tags, {"cli-hook"})
if __name__ == "__main__":
    # Imported here because the file's import block does not bring in
    # unittest or sys.
    import sys
    import unittest

    # The hook specs in this file name their factory as
    # "test_model_hooks:dummy_hook_factory". When the file is executed
    # directly it is loaded as "__main__", so resolving that string by
    # importing "test_model_hooks" would presumably create a second copy of
    # this module with its own HOOK_CALLS list, and the assertions would look
    # at the wrong list. Aliasing the running module under that name keeps a
    # single copy. NOTE(review): assumes resolve_callable imports by module
    # name; confirm against hook_manager.
    sys.modules.setdefault("test_model_hooks", sys.modules[__name__])

    # Actually run the tests; previously this guard was a no-op, so direct
    # execution exited successfully without running anything.
    unittest.main()
|