## Model Hooks

SGLang supports attaching PyTorch forward hooks to specific submodules in the loaded model, configured entirely via `server_args` JSON.

This is useful for:

* Logging intermediate activations
* Debugging model internals
* Exporting hidden states to external tooling

Hooks are attached once during `ModelRunner.initialize()` and run on every forward pass.

---
### Configuration overview

Hooks are configured via a `ServerArgs` field:

```python
from typing import Any, List, Optional


class ServerArgs:
    ...
    # For forward hooks
    forward_hooks: Optional[List[dict[str, Any]]] = None
```

In JSON form, a minimal configuration looks like:

```jsonc
{
  "forward_hooks": [
    {
      "name": "outer_linear_hooks",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer-layer"
      }
    }
  ]
}
```
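
The JSON above maps onto plain Python containers, which is the `List[dict[str, Any]]` shape that `forward_hooks` is typed as. A minimal sketch that parses the example with the standard library's `json` module and reads a spec back (how SGLang itself loads this JSON into `ServerArgs` is not shown here and is not assumed):

```python
import json
from typing import Any, Dict, List

# The minimal configuration from above, as a JSON string.
raw = """
{
  "forward_hooks": [
    {
      "name": "outer_linear_hooks",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {"tag": "outer-layer"}
    }
  ]
}
"""

# json.loads yields plain Python containers: a list of dicts, one per hook spec.
forward_hooks: List[Dict[str, Any]] = json.loads(raw)["forward_hooks"]

for spec in forward_hooks:
    print(spec["name"], spec["target_modules"], spec["config"]["tag"])
```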
#### Top-level fields

* `forward_hooks` (optional list of objects)

  Each element is a hook spec describing:

  * Which modules to target
  * Which Python factory to call
  * What configuration to pass into that factory

---
### Hook spec schema

Each entry in `forward_hooks` is a JSON object with the following shape:

```jsonc
{
  "name": "optional-descriptive-name",
  "target_modules": ["pattern1", "pattern2", "..."],
  "hook_factory": "module.submodule:factory_name",
  "config": {
    "...": "arbitrary JSON"
  }
}
```
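
Registration only checks this shape loosely (a spec missing `target_modules` or `hook_factory` is skipped with a warning). The sketch below is a hypothetical pre-flight checker, `check_hook_spec`, that is not part of SGLang; it flags the same problems up front and also catches a bare string passed as `target_modules`, which the registration loop would otherwise iterate as a sequence of one-character patterns.

```python
from typing import Any, Dict, List


def check_hook_spec(spec: Dict[str, Any]) -> List[str]:
    """Hypothetical pre-flight check for one hook spec (not part of SGLang).

    Returns a list of problems; an empty list means the spec has the
    documented shape.
    """
    problems: List[str] = []

    if not isinstance(spec.get("name", ""), str):
        problems.append("'name' must be a string")

    targets = spec.get("target_modules")
    if not targets:
        problems.append("'target_modules' is missing or empty")
    elif not isinstance(targets, list) or not all(isinstance(t, str) for t in targets):
        # A bare string such as "outer.0" would be iterated character by character.
        problems.append("'target_modules' must be a list of strings")

    factory = spec.get("hook_factory")
    if not isinstance(factory, str) or not factory:
        problems.append("'hook_factory' must be a non-empty string")

    config = spec.get("config")
    if config is not None and not isinstance(config, dict):
        problems.append("'config' must be a JSON object")

    return problems


print(check_hook_spec({"target_modules": ["outer.0"], "hook_factory": "m:f"}))
print(check_hook_spec({"target_modules": "outer.0", "hook_factory": ""}))
```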
#### `name` (optional)

* Human-readable name for logging; defaults to an empty string when omitted.
* Used only in log messages such as:

  ```text
  Registered forward hook 'outer_linear_hooks' on outer.0
  ```

#### `target_modules` (required)

* List of **module name patterns** used to match entries in `model.named_modules()`.
* A spec with a missing or empty `target_modules` is skipped with a warning.
* Patterns are matched using `fnmatch.fnmatch`, where `*` matches any run of characters, including `.`:
  * `"outer.0"` matches exactly `"outer.0"`.
  * `"outer.*"` matches `"outer.0"`, `"outer.1"`, `"outer.inner"`, and deeper descendants such as `"outer.inner.linear"`, but not `"outer"` itself.
  * `"outer.inner.*"` matches every descendant of `outer.inner`, at any depth.

> If no modules match the given patterns, hook registration does **not** fail.
> Instead, SGLang logs a warning and continues:
>
> ```text
> No modules matched hook spec 'name' patterns=['...']
> ```
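
To check what a pattern will select before starting a server, the same `fnmatch` test can be run over a list of module names. A minimal sketch; the names in `NAMES` are illustrative, in the style of `model.named_modules()` (whose root module is the empty string):

```python
import fnmatch
from typing import List

# Illustrative module names, in the style of model.named_modules() (the root is "").
NAMES = ["", "outer", "outer.0", "outer.1", "outer.inner", "outer.inner.linear"]


def matches(patterns: List[str]) -> List[str]:
    """Names matched by any pattern, using the same fnmatch test as the registry."""
    return [n for n in NAMES if any(fnmatch.fnmatch(n, p) for p in patterns)]


print(matches(["outer.0"]))        # ['outer.0']
print(matches(["outer.*"]))        # ['outer.0', 'outer.1', 'outer.inner', 'outer.inner.linear']
print(matches(["outer.inner.*"]))  # ['outer.inner.linear']
print(matches(["decoder.*"]))      # [] -> registration would only log a warning
```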
#### `hook_factory` (required)

* String path to the Python factory function that creates the hook.
* Supported formats:
  * `"package.module:factory_name"`
  * `"package.module.submodule.factory_name"`
* A spec with a missing or empty `hook_factory` is skipped with a warning.

The path is resolved via:

```python
import importlib
from typing import Callable, Optional


def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    if path is None:
        return None

    if ":" in path:
        # "package.module:factory_name"
        module_name, fn_name = path.split(":", 1)
    else:
        # "package.module.factory_name": the last component is the attribute
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid hook callable path '{path}'. "
                "Expected 'module.submodule:factory' or 'module.submodule.factory'."
            )
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)

    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e
```

**Failure modes**:

* If the path is malformed (no `:` and no `.`), a `ValueError` is raised at startup.
* If the module cannot be imported, `importlib.import_module` raises an `ImportError` (typically `ModuleNotFoundError`).
* If the module imports but the attribute is missing, an `AttributeError` is raised with a clear error message.
* If the hook factory returns `None`, a warning is logged and no hook is registered for that spec (initialization continues).

The first three cause initialization to fail fast with a descriptive error; the last one is non-fatal.
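
The resolver's behavior can be exercised without any project code by pointing it at the standard library. A minimal sketch using a condensed copy of `resolve_callable` (error messages shortened); `os.path.join` is just a convenient existing target and `no_such_pkg_xyz` is a hypothetical package name assumed not to exist:

```python
import importlib
import os.path
from typing import Callable, Optional


def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    # Condensed copy of the resolver above, with shorter error messages.
    if path is None:
        return None
    if ":" in path:
        module_name, fn_name = path.split(":", 1)
    else:
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(f"Invalid hook callable path '{path}'.")
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)
    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(f"Module '{module_name}' has no attribute '{fn_name}'") from e


# Both supported spellings resolve to the same function object.
print(resolve_callable("os.path:join") is resolve_callable("os.path.join"))

# Each malformed or unresolvable path fails with a different exception type.
errors = {}
for bad in ["join", "os.path:no_such_fn", "no_such_pkg_xyz.factory"]:
    try:
        resolve_callable(bad)
    except (ValueError, ImportError, AttributeError) as e:
        errors[bad] = type(e).__name__
print(errors)
```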
#### `config` (optional)

* Arbitrary JSON object.
* Passed to the hook factory as a Python `dict` (its only argument); a missing or `null` `config` becomes an empty `dict`.
* This lets you parameterize hook behavior from config (e.g. tags, log levels, or sampling rates).

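As an illustration of parameterizing a hook from `config`, the sketch below reads two hypothetical keys, `tag` and `every_n`, and records only every N-th call. Because `register_forward_hooks` calls the factory once per spec and attaches the same hook to every matched module, the counter here is shared across all of those modules rather than kept per module. The direct calls at the bottom use stand-in arguments in place of a real PyTorch forward pass.

```python
from typing import Any, Callable, Dict, List

RECORDS: List[Dict[str, Any]] = []


def sampled_hook_factory(config: Dict[str, Any]) -> Callable:
    """Hypothetical factory: `tag` and `every_n` are illustrative config keys."""
    tag = config.get("tag", "default")
    every_n = max(1, int(config.get("every_n", 1)))
    calls = 0  # shared by every module this single hook object is attached to

    def hook(module, inputs, output):
        nonlocal calls
        calls += 1
        if calls % every_n == 0:
            RECORDS.append({"tag": tag, "call": calls, "module": type(module).__name__})
        return None  # None leaves the module's output unchanged

    return hook


hook = sampled_hook_factory({"tag": "mlp", "every_n": 3})
for _ in range(7):
    hook(object(), (), None)  # stand-in arguments; PyTorch passes real ones
print(RECORDS)
```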
---
### Hook lifecycle and behavior

Hooks are registered in `ModelRunner.initialize()`:

```python
if server_args.forward_hooks:
    register_forward_hooks(self.model, server_args.forward_hooks)
```

The actual registration logic is implemented by `register_forward_hooks`:

```python
import fnmatch
import logging
from typing import Any, List

from torch import nn

logger = logging.getLogger(__name__)


def register_forward_hooks(model: nn.Module, hook_specs: List[dict[str, Any]]) -> None:
    """
    hook_specs is a list of dicts from server_args.forward_hooks.
    Attaches forward hooks to the matching modules.
    """
    name_to_module = dict(model.named_modules())

    for spec in hook_specs:
        spec_name = spec.get("name", "")
        target_patterns = spec.get("target_modules", [])
        if not target_patterns:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'target_modules', skipping"
            )
            continue

        hook_factory_path = spec.get("hook_factory")
        if not hook_factory_path:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'hook_factory', skipping"
            )
            continue

        config = spec.get("config") or {}
        hook_factory = resolve_callable(hook_factory_path)  # defined above

        # The factory runs once per spec; the same hook is shared by all matches.
        hook = hook_factory(config) if hook_factory else None
        if hook is None:
            logger.warning(
                f"Hook factory '{hook_factory_path}' for spec '{spec_name}' "
                "returned None, not registering any hook"
            )
            continue

        # Resolve patterns like "model.layers.*.mlp"
        matched = []
        for name, module in name_to_module.items():
            if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns):
                matched.append((name, module))

        if not matched:
            logger.warning(
                f"No modules matched hook spec '{spec_name}' "
                f"patterns={target_patterns}"
            )
            continue

        for module_name, module in matched:
            # The returned handle is discarded, so the hook cannot be removed later.
            module.register_forward_hook(hook)
            logger.info(
                f"Registered forward hook '{spec_name}' "
                f"on {module_name}"
            )
```

Key points:

* Hooks are **forward hooks only** (via `module.register_forward_hook`).
* They are attached once at initialization.
* Each spec's factory is called once, and the resulting hook object is shared by every module the spec matches.
* Hook handles are currently not stored on `ModelRunner`, so hooks cannot be removed later via this API.
* Failure to match any modules is non-fatal; a warning is logged instead.
* If a hook factory returns `None`, a warning is logged and that spec is skipped.

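Since `register_forward_hooks` discards the handle that `module.register_forward_hook` returns, removing hooks would require keeping those handles. The sketch below is a hypothetical variant, `register_forward_hooks_with_handles`, that is not part of SGLang: it returns PyTorch `RemovableHandle` objects, and the `factories` dict stands in for `resolve_callable` so no importable package is needed. The toy model reuses the `outer.0` / `outer.1` naming from the examples above.

```python
import fnmatch
from typing import Any, Callable, Dict, List

import torch
from torch import nn
from torch.utils.hooks import RemovableHandle


def register_forward_hooks_with_handles(
    model: nn.Module,
    hook_specs: List[Dict[str, Any]],
    factories: Dict[str, Callable],
) -> List[RemovableHandle]:
    """Hypothetical variant that keeps handles so hooks can be removed later."""
    handles: List[RemovableHandle] = []
    for spec in hook_specs:
        hook = factories[spec["hook_factory"]](spec.get("config") or {})
        for name, module in model.named_modules():
            if any(fnmatch.fnmatch(name, p) for p in spec["target_modules"]):
                handles.append(module.register_forward_hook(hook))
    return handles


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.outer = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.outer(x)


calls: List[str] = []


def counting_factory(config: Dict[str, Any]) -> Callable:
    def hook(module, inputs, output):
        calls.append(config["tag"])  # returns None: output is untouched

    return hook


model = Toy()
specs = [{"target_modules": ["outer.*"], "hook_factory": "counting", "config": {"tag": "t"}}]
handles = register_forward_hooks_with_handles(model, specs, {"counting": counting_factory})

model(torch.randn(2, 4))  # outer.0 and outer.1 both fire the hook
for h in handles:
    h.remove()
model(torch.randn(2, 4))  # hooks removed: nothing new is recorded
print(len(handles), calls)
```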
---
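### Writing a hook factory

A hook factory is a regular Python function that:

* Takes a `config: dict` (from JSON).
* Returns a forward hook function with signature `(module, inputs, output)`, where `inputs` is a tuple of the module's positional arguments.

The returned hook follows PyTorch's forward-hook contract: returning `None` leaves the module's output unchanged, while returning any other value replaces it.

Example:
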
```python
HOOK_CALLS = []


def dummy_hook_factory(config):
    """Factory that returns a forward hook capturing a tag from config."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        # Assumes `output` is a tensor; modules that return tuples need extra handling.
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                "shape": tuple(output.shape),
            }
        )
        # Returning None or `output` itself leaves the output unchanged;
        # returning a different value would replace the module's output.
        return output

    return hook
```
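
Whether a hook only observes or actually rewrites a module's output depends solely on its return value. A minimal standalone check on a single `nn.Linear` (illustrative, not SGLang code):

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(3, 3)
x = torch.randn(2, 3)
baseline = layer(x)

# Observer hook: returns None, so PyTorch keeps the original output.
handle = layer.register_forward_hook(lambda module, inputs, output: None)
unchanged = torch.equal(layer(x), baseline)
handle.remove()

# Rewriting hook: the returned tensor becomes the module's output.
handle = layer.register_forward_hook(lambda module, inputs, output: torch.zeros_like(output))
zeroed = layer(x)
handle.remove()

print(unchanged, zeroed.abs().sum().item())
```

In practice, a hook meant purely for logging or export should avoid returning a new tensor, since that silently changes the model's computation.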

In JSON:

```jsonc
{
  "forward_hooks": [
    {
      "name": "capture_outer",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer"
      }
    }
  ]
}
```

This will:

* Resolve `my_project.hooks:dummy_hook_factory` to a Python callable.
* Call it once with `config = {"tag": "outer"}`.
* Attach the returned hook to every module matching `outer.0` or `outer.1`.
* On each forward pass through a matched module, append metadata about the call to `HOOK_CALLS`.

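Those steps can be reproduced end to end on a toy model. Because `my_project.hooks` is hypothetical, this sketch calls the factory directly instead of resolving the `hook_factory` string, then attaches the hook with the same `fnmatch` test the registry uses. The model class and layer sizes are assumptions chosen so the module names are `outer.0` and `outer.1`.

```python
import fnmatch
from typing import Any, Dict, List

import torch
from torch import nn

HOOK_CALLS: List[Dict[str, Any]] = []


def dummy_hook_factory(config):
    # Same idea as the factory above, repeated so this sketch runs on its own.
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {"module_type": type(module).__name__, "tag": tag, "shape": tuple(output.shape)}
        )  # returns None, so outputs are untouched

    return hook


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.outer = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.outer(x)


spec = {"name": "capture_outer", "target_modules": ["outer.0", "outer.1"], "config": {"tag": "outer"}}

model = Toy()
hook = dummy_hook_factory(spec["config"])  # one hook object for the whole spec
for name, module in model.named_modules():
    if any(fnmatch.fnmatch(name, p) for p in spec["target_modules"]):
        module.register_forward_hook(hook)

with torch.no_grad():
    model(torch.randn(2, 8))
print(HOOK_CALLS)
```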
---
### Summary

* Define `forward_hooks` as a list of specs in `ServerArgs` to turn on the feature.
* Each spec:
  * selects modules via `target_modules` (glob patterns over `model.named_modules()`),
  * points to a hook factory via `hook_factory`,
  * passes arbitrary `config` into that factory.
* Hook factories are resolved via `resolve_callable`, which supports `module:factory` and `module.submodule.factory`.
* Hooks are standard PyTorch forward hooks, attached once at startup and invoked on every forward pass.
* Misconfiguration is either:
  * **fatal and explicit** (malformed path, unimportable module, or missing attribute), or
  * **non-fatal with clear warnings** (missing `target_modules` or `hook_factory`, no targets matched, or factory returned `None`).