## Model Hooks

SGLang supports attaching PyTorch forward hooks to specific submodules in the loaded model, configured entirely via `server_args` JSON.

This is useful for:

* Logging intermediate activations
* Debugging model internals
* Exporting hidden states to external tooling

Hooks are attached once during `ModelRunner.initialize()` and run on every forward pass.

---

### Configuration overview

Hooks are configured via a `ServerArgs` field:

```python
class ServerArgs:
    ...
    # For forward hooks
    forward_hooks: Optional[List[dict[str, Any]]] = None
```

In JSON form, a minimal configuration looks like:

```jsonc
{
  "forward_hooks": [
    {
      "name": "outer_linear_hooks",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer-layer"
      }
    }
  ]
}
```

#### Top-level fields

* `forward_hooks` (optional list of objects)
  Each element is a hook spec describing:

  * Which modules to target
  * Which Python factory to call
  * What configuration to pass into that factory

---

### Hook spec schema

Each entry in `forward_hooks` is a JSON object with the following shape:

```jsonc
{
  "name": "optional-descriptive-name",
  "target_modules": ["pattern1", "pattern2", "..."],
  "hook_factory": "module.submodule:factory_name",
  "config": {
    "...": "arbitrary JSON"
  }
}
```

#### `name` (optional)

* Human-readable name for logging.
* Used only in log messages such as:

  ```text
  Registered forward hook 'outer_linear_hooks' on outer.0
  ```

#### `target_modules` (required)

* List of **module name patterns** used to match entries in `model.named_modules()`.
* Patterns are matched using `fnmatch.fnmatch`, so:

  * `"outer.0"` matches exactly `"outer.0"`.
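  * `"outer.*"` matches `"outer.0"`, `"outer.1"`, `"outer.inner"`, etc. Because `*` in `fnmatch` also matches dots, it matches deeper descendants (a name like `"outer.inner.linear"`) as well.
  * `"outer.inner.*"` matches every descendant of `outer.inner`, but not `outer.inner` itself.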

> If no modules match the given patterns, hook registration does **not** fail.
> Instead, SGLang logs a warning and continues:
>
> ```text
> No modules matched hook spec 'name' patterns=['...']
> ```
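
The matching rules above can be checked without loading a model. The sketch below is illustrative only: the module names are hypothetical stand-ins for the keys of `model.named_modules()`, and `match_modules` is a helper introduced here, not part of SGLang.

```python
import fnmatch

# Hypothetical module names, standing in for the keys of model.named_modules().
module_names = ["", "outer", "outer.0", "outer.1", "outer.inner", "outer.inner.linear"]


def match_modules(names, patterns):
    """Return the names matched by at least one pattern, in input order."""
    return [n for n in names if any(fnmatch.fnmatch(n, p) for p in patterns)]


exact = match_modules(module_names, ["outer.0"])
wide = match_modules(module_names, ["outer.*"])
nested = match_modules(module_names, ["outer.inner.*"])

print(exact)   # only the exact name
print(wide)    # every name that starts with "outer.", at any depth
print(nested)  # descendants of outer.inner, not outer.inner itself
```

Note that `"outer.*"` does not match `"outer"` itself, because the pattern requires the literal dot.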

#### `hook_factory` (required)

* String path to the Python factory function that creates the hook.
* Supported formats:

  * `"package.module:factory_name"`
  * `"package.module.submodule.factory_name"`

The path is resolved via:

```python
def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    if path is None:
        return None

    if ":" in path:
        module_name, fn_name = path.split(":", 1)
    else:
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid hook callable path '{path}'. "
                "Expected 'module.submodule:factory' or 'module.submodule.factory'."
            )
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)

    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e
```

**Failure modes**:

* If the path is malformed (no `:` and no `.`), a `ValueError` is raised at startup.
* If the module cannot be imported, `importlib.import_module` raises an `ImportError` (typically `ModuleNotFoundError`).
* If the module imports but the attribute is missing, an `AttributeError` is raised with a message naming the module, the attribute, and the original path.
* If the hook factory returns `None`, a warning is logged and no hook is registered for that spec (initialization continues).

The exceptions are not caught during registration, so they cause initialization to fail fast with a descriptive error; only the `None` case is non-fatal.
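
To see the two path formats and the error types side by side, the resolver can be exercised against standard-library modules. This is a sketch: `resolve_callable` is copied from the listing above, while `describe` and the paths passed to it (including the deliberately nonexistent `no_such_package_xyz`) are made up for illustration.

```python
import importlib
from typing import Callable, Optional


def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    # Copy of the resolver shown above, so this example runs on its own.
    if path is None:
        return None
    if ":" in path:
        module_name, fn_name = path.split(":", 1)
    else:
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(f"Invalid hook callable path '{path}'.")
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)
    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e


def describe(path):
    """Hypothetical helper: resolved callable's name, or the exception type."""
    try:
        return resolve_callable(path).__name__
    except Exception as exc:
        return type(exc).__name__


paths = [
    "json:dumps",                          # colon form
    "json.dumps",                          # dotted form
    "dumps",                               # malformed: no ':' and no '.'
    "json:no_such_factory",                # module imports, attribute missing
    "no_such_package_xyz.hooks:factory",   # module cannot be imported
]
results = {p: describe(p) for p in paths}
for p, outcome in results.items():
    print(f"{p!r:40} -> {outcome}")
```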

#### `config` (optional)

* Arbitrary JSON object.
* Passed directly to the hook factory as a Python `dict`.
* This lets you parameterize hook behavior from config (e.g. tags, log levels, sampling rates, etc.).
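
Because a spec with missing fields is skipped with only a warning, it can help to check a configuration before starting the server. The following is a minimal sketch of such a pre-flight check; `validate_hook_specs` is a hypothetical helper, not an SGLang API, and it only mirrors the field rules described in this section.

```python
import json
from typing import Any, Dict, List


def validate_hook_specs(specs: List[Dict[str, Any]]) -> List[str]:
    """Return human-readable problems; an empty list means the specs look usable."""
    problems = []
    for i, spec in enumerate(specs):
        label = spec.get("name") or f"#{i}"

        targets = spec.get("target_modules")
        if (
            not isinstance(targets, list)
            or not targets
            or not all(isinstance(t, str) for t in targets)
        ):
            problems.append(f"{label}: 'target_modules' must be a non-empty list of strings")

        factory = spec.get("hook_factory")
        if not isinstance(factory, str) or not factory:
            problems.append(f"{label}: 'hook_factory' must be a non-empty string")
        elif ":" not in factory and "." not in factory:
            problems.append(f"{label}: 'hook_factory' needs 'module:factory' or 'module.factory' form")

        config = spec.get("config")
        if config is not None and not isinstance(config, dict):
            problems.append(f"{label}: 'config' must be a JSON object")
    return problems


raw = """
[
  {"name": "ok", "target_modules": ["outer.*"],
   "hook_factory": "my_project.hooks:dummy_hook_factory", "config": {"tag": "outer"}},
  {"name": "broken", "target_modules": [], "hook_factory": "factory"}
]
"""
problems = validate_hook_specs(json.loads(raw))
for problem in problems:
    print(problem)
```

The check is purely structural: it does not import the factory or match patterns against a model.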

---

### Hook lifecycle and behavior

Hooks are registered in `ModelRunner.initialize()`:

```python
if server_args.forward_hooks:
    register_forward_hooks(self.model, server_args.forward_hooks)
```

The actual registration logic is implemented by `register_forward_hooks`:

```python
def register_forward_hooks(model: nn.Module, hook_specs: List[dict[str, Any]]) -> None:
    """
    hook_specs is a list of dicts from server_args.forward_hooks.
    Attaches forward hooks to the matching modules.
    """
    name_to_module = dict(model.named_modules())

    for spec in hook_specs:
        spec_name = spec.get("name", "")
        target_patterns = spec.get("target_modules", [])
        if not target_patterns:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'target_modules', skipping"
            )
            continue

        hook_factory_path = spec.get("hook_factory")
        if not hook_factory_path:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'hook_factory', skipping"
            )
            continue

        config = spec.get("config") or {}
        hook_factory = resolve_callable(hook_factory_path)

        hook = hook_factory(config) if hook_factory else None
        if hook is None:
            logger.warning(
                f"Hook factory '{hook_factory_path}' for spec '{spec_name}' "
                "returned None, not registering any hook"
            )
            continue

        # Resolve patterns like "model.layers.*.mlp"
        matched = []
        for name, module in name_to_module.items():
            if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns):
                matched.append((name, module))

        if not matched:
            logger.warning(
                f"No modules matched hook spec '{spec_name}' "
                f"patterns={target_patterns}"
            )
            continue

        for module_name, module in matched:
            if hook:
                _ = module.register_forward_hook(hook)
                logger.info(
                    f"Registered forward hook '{spec_name}' "
                    f"on {module_name}"
                )
```

Key points:

* Hooks are **forward hooks only** (via `module.register_forward_hook`).
* They are attached once at initialization.
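* Hook handles are not currently stored on `ModelRunner`, so hooks cannot be removed later via this API.
* The hook factory is called once per spec, and the same hook object is registered on every matched module.
* A spec with a missing or empty `target_modules` or `hook_factory` is skipped with a warning, even though both fields are documented as required.
* Failure to match any modules is non-fatal; a warning is logged instead.
* If a hook factory returns `None`, a warning is logged and that spec is skipped.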

---

### Writing a hook factory

A hook factory is a regular Python function:

* Takes a `config: dict` (from JSON)
* Returns a forward hook function with signature `(module, inputs, output)`

Example:

```python
HOOK_CALLS = []

def dummy_hook_factory(config):
    """Factory that returns a forward hook capturing a tag from config."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                "shape": tuple(output.shape),  # assumes the module returns a single tensor
            }
        )
        # Returning None (or nothing) leaves the module's output unchanged;
        # any other return value replaces it.
        return output

    return hook
```

In JSON:

```jsonc
{
  "forward_hooks": [
    {
      "name": "capture_outer",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer"
      }
    }
  ]
}
```

This will:

* Resolve `my_project.hooks:dummy_hook_factory` to a Python callable.
* Call it with `config = {"tag": "outer"}`.
* Use the returned hook for all modules matching `outer.0` and `outer.1`.
* Append metadata about each call to `HOOK_CALLS`.
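
As a further illustration of parameterizing a hook through `config`, the sketch below records only every N-th call. Everything in it is hypothetical: the factory name, the `sample_every` key, and the `records` attribute are invented for this example. The final loop calls the hook directly with a stand-in object instead of a real tensor, so it runs without PyTorch.

```python
from types import SimpleNamespace


def sampling_hook_factory(config):
    """Hypothetical factory: record the output shape on every N-th call."""
    every = max(1, int(config.get("sample_every", 1)))
    tag = config.get("tag", "default")
    records = []
    calls = 0

    def hook(module, inputs, output):
        nonlocal calls
        calls += 1
        if calls % every == 0:
            # Outputs without a .shape attribute (e.g. tuples) are recorded as None.
            shape = getattr(output, "shape", None)
            records.append(
                {
                    "tag": tag,
                    "module_type": type(module).__name__,
                    "shape": tuple(shape) if shape is not None else None,
                }
            )
        # No return statement: returning None leaves the module's output unchanged.

    hook.records = records  # expose the captured data for inspection
    return hook


# Stand-in for a forward pass: call the hook directly with a fake output.
hook = sampling_hook_factory({"tag": "outer", "sample_every": 2})
fake_output = SimpleNamespace(shape=(4, 8))
for _ in range(5):
    hook(object(), (), fake_output)

print(hook.records)
```

Since the factory is called once per spec, the call counter here would be shared by all modules the spec matches; a per-module counter would need to key on the `module` argument.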

---

### Summary

* Define `forward_hooks` as a list of specs in `ServerArgs` to turn on the feature.

* Each spec:

  * selects modules via `target_modules` (glob patterns over `model.named_modules()`),
  * points to a hook factory via `hook_factory`,
  * passes arbitrary `config` into that factory.

* Hook factories are resolved via `resolve_callable`, which supports `module:factory` and `module.submodule.factory`.

* Hooks are standard PyTorch forward hooks, attached once at startup and invoked on every forward pass.

* Misconfiguration is either:

  * **fatal and explicit** (bad path, unimportable module, or missing attribute), or
  * **non-fatal with clear warnings** (missing `target_modules` or `hook_factory`, no targets matched, or factory returned `None`).