## Model Hooks
SGLang supports attaching PyTorch forward hooks to specific submodules in the loaded model, configured entirely via `server_args` JSON.
This is useful for:
* Logging intermediate activations
* Debugging model internals
* Exporting hidden states to external tooling
Hooks are attached once during `ModelRunner.initialize` and run on every forward pass.
---
### Configuration overview
Hooks are configured via a `ServerArgs` field:
```python
class ServerArgs:
    ...
    # For forward hooks
    forward_hooks: Optional[List[dict[str, Any]]] = None
```
In JSON form, a minimal configuration looks like:
```jsonc
{
  "forward_hooks": [
    {
      "name": "outer_linear_hooks",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer-layer"
      }
    }
  ]
}
```
#### Top-level fields
* `forward_hooks` (optional list of objects)
  Each element is a hook spec describing:
  * Which modules to target
  * Which Python factory to call
  * What configuration to pass into that factory
---
### Hook spec schema
Each entry in `forward_hooks` is a JSON object with the following shape:
```jsonc
{
  "name": "optional-descriptive-name",
  "target_modules": ["pattern1", "pattern2", "..."],
  "hook_factory": "module.submodule:factory_name",
  "config": {
    "...": "arbitrary JSON"
  }
}
```
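A spec list can be sanity-checked before the server starts. The following is a minimal sketch, assuming the configuration is available as a JSON string; `check_hook_specs` is a hypothetical helper for illustration and is not part of SGLang.

```python
import json


def check_hook_specs(specs):
    """Return a list of problems found; an empty list means none were found."""
    problems = []
    for i, spec in enumerate(specs):
        label = spec.get("name") or f"#{i}"
        patterns = spec.get("target_modules")
        if not isinstance(patterns, list) or not patterns:
            problems.append(f"{label}: 'target_modules' must be a non-empty list")
        factory = spec.get("hook_factory")
        if not isinstance(factory, str) or not factory:
            problems.append(f"{label}: 'hook_factory' must be a non-empty string")
    return problems


raw = """
{
  "forward_hooks": [
    {"name": "ok", "target_modules": ["outer.*"], "hook_factory": "my_project.hooks:f"},
    {"name": "broken", "target_modules": []}
  ]
}
"""
specs = json.loads(raw)["forward_hooks"]
problems = check_hook_specs(specs)
for line in problems:
    print(line)
```

A check like this only covers the shape of the spec; whether the factory path imports and whether any module matches is still decided at startup.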
#### `name` (optional)
* Human-readable name for logging.
* Used only in log messages such as:
```text
Registered forward hook 'outer_linear_hooks' on outer.0
```
#### `target_modules` (required)
* List of **module name patterns** used to match entries in `model.named_modules()`.
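* Patterns are matched using `fnmatch.fnmatch`, so:
  * `"outer.0"` matches exactly `"outer.0"`.
  * `"outer.*"` matches `"outer.0"`, `"outer.1"`, `"outer.inner"`, etc. Because `*` also matches `.`, it matches deeper descendants of `outer` as well.
  * `"outer.inner.*"` matches every descendant of `outer.inner`, but not `outer.inner` itself.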
> If no modules match the given patterns, hook registration does **not** fail.
> Instead, SGLang logs a warning and continues:
>
> ```text
> No modules matched hook spec 'name' patterns=['...']
> ```
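To preview which modules a pattern would select, the matching can be tried on a plain list of names. This is a small sketch: the module names are made up for illustration, and `match_modules` is a hypothetical helper. In practice the names would come from `model.named_modules()`.

```python
import fnmatch

# Hypothetical module names, standing in for the keys of model.named_modules().
module_names = ["", "outer", "outer.0", "outer.1", "outer.inner", "outer.inner.linear"]


def match_modules(names, patterns):
    """Return the names matched by at least one pattern, in their original order."""
    return [n for n in names if any(fnmatch.fnmatch(n, p) for p in patterns)]


exact = match_modules(module_names, ["outer.0"])
wildcard = match_modules(module_names, ["outer.*"])
nested = match_modules(module_names, ["outer.inner.*"])

print(exact)     # ['outer.0']
print(wildcard)  # ['outer.0', 'outer.1', 'outer.inner', 'outer.inner.linear']
print(nested)    # ['outer.inner.linear']
```

Note that `"outer.*"` does not match `"outer"` itself, and that `*` crosses `.` boundaries, so a broad pattern can attach the hook to more modules than intended.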
#### `hook_factory` (required)
* String path to the Python factory function that creates the hook.
* Supported formats:
  * `"package.module:factory_name"`
  * `"package.module.submodule.factory_name"`
The path is resolved via:
```python
import importlib
from typing import Callable, Optional


def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    if path is None:
        return None
    if ":" in path:
        # "package.module:factory_name"
        module_name, fn_name = path.split(":", 1)
    else:
        # "package.module.factory_name": the last component is the attribute
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid hook callable path '{path}'. "
                "Expected 'module.submodule:factory' or 'module.submodule.factory'."
            )
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)
    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e
```
**Failure modes**:
* If the path is malformed (it contains neither `:` nor `.`), a `ValueError` is raised at startup.
* If the module cannot be imported, `importlib.import_module` raises an `ImportError` (typically `ModuleNotFoundError`).
* If the module imports but the attribute is missing, an `AttributeError` is raised with a message naming the module, the attribute, and the original path.
* If the hook factory returns `None`, a warning is logged and no hook is registered for that spec (initialization continues).
All but the last cause initialization to fail fast with a descriptive error; the last one is non-fatal.
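The two path formats and the fatal failure modes can be exercised against the standard library. This is a sketch only: `split_hook_path` and `load` are hypothetical helpers that mirror the parsing rules of `resolve_callable` shown above, and `math.sqrt` stands in for a real hook factory.

```python
import importlib


def split_hook_path(path):
    """Split a hook path into (module_name, attribute_name)."""
    if ":" in path:
        module_name, fn_name = path.split(":", 1)
    else:
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(f"Invalid hook callable path '{path}'")
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)
    return module_name, fn_name


def load(path):
    module_name, fn_name = split_hook_path(path)
    return getattr(importlib.import_module(module_name), fn_name)


# Both spellings resolve to the same object.
colon_form = load("math:sqrt")
dotted_form = load("math.sqrt")
print(colon_form is dotted_form)  # True

# Each malformed path fails with a different exception type.
errors = {}
for bad in ["sqrt", "math:no_such_fn", "no_such_module_xyz:fn"]:
    try:
        load(bad)
    except Exception as exc:
        errors[bad] = type(exc).__name__
print(errors)
```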
#### `config` (optional)
* Arbitrary JSON object.
* Passed directly to the hook factory as a Python `dict`.
* This lets you parameterize hook behavior from config (e.g. tags, log levels, sampling rates, etc.).
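As an illustration of config-driven behavior, the sketch below reads two keys with defaults and records only every n-th call. The factory name `sampling_hook_factory`, the `every_n` key, and the `records` attribute are assumptions made for this example; SGLang does not define them.

```python
def sampling_hook_factory(config):
    """Hypothetical factory: record every `every_n`-th call, tagged with `tag`."""
    tag = config.get("tag", "default")
    every_n = int(config.get("every_n", 1))
    if every_n < 1:
        raise ValueError("every_n must be >= 1")

    state = {"calls": 0}
    records = []

    def hook(module, inputs, output):
        state["calls"] += 1
        if state["calls"] % every_n == 0:
            records.append({"tag": tag, "call": state["calls"]})
        # Implicitly returns None, which leaves the module output unchanged.

    hook.records = records  # exposed only so this sketch can be inspected
    return hook


# Simulate seven forward passes by calling the hook directly.
hook = sampling_hook_factory({"tag": "mlp", "every_n": 3})
for _ in range(7):
    hook(None, (), None)
print(hook.records)  # [{'tag': 'mlp', 'call': 3}, {'tag': 'mlp', 'call': 6}]
```

Validating `config` inside the factory means a bad value surfaces when the factory is called at startup rather than on the first forward pass.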
---
### Hook lifecycle and behavior
Hooks are registered in `ModelRunner.initialize()`:
```python
if server_args.forward_hooks:
    register_forward_hooks(self.model, server_args.forward_hooks)
```
The actual registration logic is implemented by `register_forward_hooks`:
```python
def register_forward_hooks(model: nn.Module, hook_specs: List[dict[str, Any]]) -> None:
    """
    hook_specs is a list of dicts from server_args.forward_hooks.
    Attaches forward hooks to the matching modules.
    """
    name_to_module = dict(model.named_modules())
    for spec in hook_specs:
        spec_name = spec.get("name", "")
        target_patterns = spec.get("target_modules", [])
        if not target_patterns:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'target_modules', skipping"
            )
            continue
        hook_factory_path = spec.get("hook_factory")
        if not hook_factory_path:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'hook_factory', skipping"
            )
            continue
        config = spec.get("config") or {}
        hook_factory = resolve_callable(hook_factory_path)
        hook = hook_factory(config) if hook_factory else None
        if hook is None:
            logger.warning(
                f"Hook factory '{hook_factory_path}' for spec '{spec_name}' "
                "returned None, not registering any hook"
            )
            continue
        # Resolve patterns like "model.layers.*.mlp"
        matched = []
        for name, module in name_to_module.items():
            if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns):
                matched.append((name, module))
        if not matched:
            logger.warning(
                f"No modules matched hook spec '{spec_name}' "
                f"patterns={target_patterns}"
            )
            continue
        for module_name, module in matched:
            if hook:
                _ = module.register_forward_hook(hook)
                logger.info(
                    f"Registered forward hook '{spec_name}' "
                    f"on {module_name}"
                )
```
Key points:
* Hooks are **forward hooks only** (via `module.register_forward_hook`).
* They are attached once at initialization.
* Hook handles are currently not stored on `ModelRunner` (they cannot be removed later via this API).
* Failure to match any modules is non-fatal; a warning is logged instead.
* If a hook factory returns `None`, a warning is logged and that spec is skipped.
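One consequence of the registration loop above is that the factory is called once per spec and the same hook callable is registered on every matched module, so any state captured in the closure is shared across those modules. The sketch below illustrates this without PyTorch: `Linear` and `ReLU` are empty stand-in classes, and `counting_hook_factory` and its `counts` attribute are hypothetical.

```python
from collections import Counter


def counting_hook_factory(config):
    """Hypothetical factory whose hook counts calls per module type."""
    counts = Counter()

    def hook(module, inputs, output):
        counts[type(module).__name__] += 1

    hook.counts = counts  # exposed only so this sketch can be inspected
    return hook


class Linear:  # stand-in for a matched module
    pass


class ReLU:  # stand-in for a matched module
    pass


hook = counting_hook_factory({})        # one call per spec
matched = [Linear(), ReLU(), Linear()]  # one shared hook for all matches
for module in matched:
    hook(module, (), None)  # PyTorch would make this call after module.forward(...)

print(dict(hook.counts))  # {'Linear': 2, 'ReLU': 1}
```

The hook receives the module object but not its name from `named_modules()`, so distinguishing two modules of the same type would need separate specs with different `config` values, or a lookup keyed on the module object.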
---
### Writing a hook factory
A hook factory is a regular Python function:
* Takes a `config: dict` (from JSON)
* Returns a forward hook function with signature `(module, inputs, output)`
Example:
```python
HOOK_CALLS = []


def dummy_hook_factory(config):
    """Factory that returns a forward hook capturing a tag from config."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                # Assumes the module returns a single tensor.
                "shape": tuple(output.shape),
            }
        )
        # Returning `output` (or `None`) leaves the result unchanged;
        # returning any other value replaces the module's output.
        return output

    return hook
```
In JSON:
```jsonc
{
  "forward_hooks": [
    {
      "name": "capture_outer",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer"
      }
    }
  ]
}
```
This will:
* Resolve `my_project.hooks:dummy_hook_factory` to a Python callable.
* Call it with `config = {"tag": "outer"}`.
* Use the returned hook for all modules matching `outer.0` and `outer.1`.
* Append metadata about each call to `HOOK_CALLS`.
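The example hook reads `output.shape`, which works when the targeted module returns a single tensor. Some modules return tuples instead, so a more tolerant variant might look like the sketch below. `shape_of`, `shape_logging_hook_factory`, and `FakeTensor` are hypothetical names; `FakeTensor` only mimics the `.shape` attribute so the hook can be exercised without loading a model.

```python
def shape_of(output):
    """Best-effort shape summary for a tensor-like object or a tuple/list of them."""
    if hasattr(output, "shape"):
        return tuple(output.shape)
    if isinstance(output, (tuple, list)):
        return tuple(shape_of(item) for item in output)
    return None  # unknown output type


def shape_logging_hook_factory(config):
    tag = config.get("tag", "default")
    calls = []

    def hook(module, inputs, output):
        calls.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                "shape": shape_of(output),
            }
        )

    hook.calls = calls  # exposed only so this sketch can be inspected
    return hook


class FakeTensor:
    """Stand-in exposing only a `.shape` attribute."""

    def __init__(self, *shape):
        self.shape = shape


hook = shape_logging_hook_factory({"tag": "attn"})
hook(object(), (), FakeTensor(2, 8))
hook(object(), (), (FakeTensor(2, 8), None))
print(hook.calls[0]["shape"])  # (2, 8)
print(hook.calls[1]["shape"])  # ((2, 8), None)
```

Calling the hook directly with stand-in objects, as done here, is also a cheap way to unit-test a factory before pointing a server configuration at it.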
---
### Summary
* Define `forward_hooks` as a list of specs in `ServerArgs` to turn on the feature.
* Each spec:
  * selects modules via `target_modules` (glob patterns over `model.named_modules()`),
  * points to a hook factory via `hook_factory`,
  * passes arbitrary `config` into that factory.
* Hook factories are resolved via `resolve_callable`, which supports `module:factory` and `module.submodule.factory`.
* Hooks are standard PyTorch forward hooks, attached once at startup and invoked on every forward pass.
* Misconfiguration is either:
  * **fatal and explicit** (bad path / missing attribute), or
  * **non-fatal with clear warnings** (no targets matched, or factory returned `None`).