File size: 4,443 Bytes
61ba51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
import json

import torch
import torch.nn as nn

from sglang.srt.model_executor.hook_manager import register_forward_hooks
from sglang.srt.server_args import ServerArgs
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.test_utils import CustomTestCase

# CI registration for this test file: one call for the CUDA suite and one for
# the AMD suite. NOTE(review): presumably these enroll the file in the named
# suites with an estimated runtime; the units of est_time are defined by
# ci_register and are not visible here — confirm before tuning.
register_cuda_ci(est_time=6, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=10, suite="stage-b-test-small-1-gpu-amd")

# Module-level log of hook invocations. Hooks built by dummy_hook_factory
# append one record per call, and TestAttachHooks.setUp clears it before
# every test so counts never leak between tests.
HOOK_CALLS = []


def dummy_hook_factory(config):
    """Build a forward hook that logs every call under a configured tag.

    The tag is read from ``config["tag"]`` once, when the factory runs, and
    falls back to ``"default"`` if the key is missing. The returned hook
    appends a record (module class name, tag, output shape as a tuple) to
    ``HOOK_CALLS`` and hands the output back unchanged.
    """
    label = config.get("tag", "default")

    def hook(module, inputs, output):
        record = {
            "module_type": type(module).__name__,
            "tag": label,
            "shape": tuple(output.shape),
        }
        HOOK_CALLS.append(record)
        return output

    return hook


class TinyModel(nn.Module):
    """Small two-level network used as a target for forward hooks.

    ``inner`` is a Linear(4, 2) + ReLU stack. ``outer`` is a Linear(4, 4) +
    ReLU stack that ends with the very same ``inner`` module object, so the
    inner layers are reachable under the ``inner.*`` names.
    """

    def __init__(self):
        super().__init__()
        # Build the inner stack first so parameter initialisation order (and
        # therefore RNG consumption) stays inner-then-outer.
        inner_stack = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
        self.inner = inner_stack
        self.outer = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), inner_stack)

    def forward(self, x):
        """Run ``x`` through ``outer`` (which itself ends with ``inner``)."""
        result = self.outer(x)
        return result


class TestAttachHooks(CustomTestCase):
    """Exercise register_forward_hooks end to end against a TinyModel."""

    def setUp(self):
        # Every test starts from an empty call log.
        HOOK_CALLS.clear()

    @staticmethod
    def _forward_once(model):
        """Push one random (3, 4) batch through ``model``; output is discarded."""
        batch = torch.randn(3, 4)
        model(batch)

    @staticmethod
    def _recorded_tags():
        """Return the set of tags present in the recorded hook calls."""
        return {record["tag"] for record in HOOK_CALLS}

    def test_hook_is_attached(self):
        """Hooks named by a factory string get installed and fire on forward."""
        hook_specs = [
            {
                "target_modules": targets,
                "hook_factory": "test_model_hooks:dummy_hook_factory",
                "config": {"tag": "forward-ok"},
            }
            for targets in (["outer.0", "outer.1"], ["inner.*"])
        ]

        net = TinyModel()
        register_forward_hooks(net, hook_specs)
        self._forward_once(net)

        self.assertEqual(
            len(HOOK_CALLS),
            4,
            "Forward hook was not called correct number of times",
        )
        self.assertIn("forward-ok", self._recorded_tags())

    def test_no_matching_modules_does_not_crash(self):
        """A spec whose patterns match nothing must be accepted silently."""
        net = TinyModel()
        unmatched_spec = {
            "name": "no_match",
            "target_modules": ["does_not_exist.*"],
            "hook_factory": "test_model_hooks:dummy_hook_factory",
            "config": {"tag": "unused"},
        }

        register_forward_hooks(net, [unmatched_spec])
        self._forward_once(net)

        # Nothing was hooked, so nothing may have been recorded.
        self.assertEqual(len(HOOK_CALLS), 0)

    def test_cli_hooks_reach_model(self):
        """
        Hooks given on the command line must survive parsing into ServerArgs,
        be accepted by register_forward_hooks, and fire during a forward pass.
        """
        hooks_spec = [
            {
                "name": "outer_and_inner_from_cli",
                "target_modules": ["outer.0", "outer.1", "inner.*"],
                "hook_factory": "test_model_hooks:dummy_hook_factory",
                "config": {"tag": "cli-hook"},
            }
        ]
        encoded_hooks = json.dumps(hooks_spec)

        parser = argparse.ArgumentParser()
        ServerArgs.add_cli_args(parser)
        parsed = parser.parse_args(
            [
                "--model-path",
                "Qwen/Qwen2-7B-Instruct",  # Dummy value; not used in this test
                "--forward-hooks",
                encoded_hooks,
            ]
        )
        server_args = ServerArgs.from_cli_args(parsed)

        # The JSON round trip must hand back the spec unchanged.
        self.assertEqual(server_args.forward_hooks, hooks_spec)

        net = TinyModel()
        register_forward_hooks(net, server_args.forward_hooks)
        self._forward_once(net)

        # Hooks land on outer.0, outer.1, inner.0 and inner.1 => 4 calls
        self.assertEqual(
            len(HOOK_CALLS),
            4,
            "CLI-configured hooks did not fire expected number of times",
        )
        self.assertEqual(self._recorded_tags(), {"cli-hook"})


if __name__ == "__main__":
    # Run the tests when the file is executed directly. Previously this guard
    # was a bare `pass`, so `python <this file>` ran zero tests and still
    # exited successfully.
    import sys
    import unittest

    # The hook specs in this file name their factory as
    # "test_model_hooks:dummy_hook_factory". When run as a script this module
    # is loaded as "__main__", so a fresh import of "test_model_hooks" would
    # create a second module object with its own HOOK_CALLS list and the
    # assertions here would never see the recorded calls. Alias the name to
    # this module so both sides share one list.
    # NOTE(review): assumes the factory string is resolved via a normal
    # import (which consults sys.modules first) — confirm in hook_manager.
    sys.modules.setdefault("test_model_hooks", sys.modules[__name__])
    unittest.main()