## Model Hooks

SGLang supports attaching PyTorch forward hooks to specific submodules in the loaded model, configured entirely via `server_args` JSON. This is useful for:

* Logging intermediate activations
* Debugging model internals
* Exporting hidden states to external tooling

Hooks are attached once during `ModelRunner.initialize` and run on every forward pass.

---

### Configuration overview

Hooks are configured via a `ServerArgs` field:

```python
from typing import Any, List, Optional


class ServerArgs:
    ...
    # For forward hooks
    forward_hooks: Optional[List[dict[str, Any]]] = None
```

In JSON form, a minimal configuration looks like:

```jsonc
{
  "forward_hooks": [
    {
      "name": "outer_linear_hooks",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": {
        "tag": "outer-layer"
      }
    }
  ]
}
```

#### Top-level fields

* `forward_hooks` (optional list of objects)

  Each element is a hook spec describing:

  * Which modules to target
  * Which Python factory to call
  * What configuration to pass into that factory

---

### Hook spec schema

Each entry in `forward_hooks` is a JSON object with the following shape:

```jsonc
{
  "name": "optional-descriptive-name",
  "target_modules": ["pattern1", "pattern2", "..."],
  "hook_factory": "module.submodule:factory_name",
  "config": {
    "...": "arbitrary JSON"
  }
}
```

#### `name` (optional)

* Human-readable name for logging.
* Used only in log messages such as:

  ```text
  Registered forward hook 'outer_linear_hooks' on outer.0
  ```

#### `target_modules` (required)

* List of **module name patterns** used to match entries in `model.named_modules()`.
* Patterns are matched using `fnmatch.fnmatch`, so:

  * `"outer.0"` matches exactly `"outer.0"`.
  * `"outer.*"` matches `"outer.0"`, `"outer.1"`, `"outer.inner"`, and also deeper descendants such as `"outer.inner.proj"`, because `*` matches any run of characters, including `.`.
  * `"outer.inner.*"` matches every descendant of `outer.inner`, but not `"outer.inner"` itself.
* If `target_modules` is missing or empty, the spec is skipped with a warning.

> If no modules match the given patterns, hook registration does **not** fail.
> Instead, SGLang logs a warning and continues:
>
> ```text
> No modules matched hook spec 'name' patterns=['...']
> ```

#### `hook_factory` (required)

* String path to the Python factory function that creates the hook.
* Supported formats:

  * `"package.module:factory_name"`
  * `"package.module.submodule.factory_name"`
* If `hook_factory` is missing or empty, the spec is skipped with a warning.

The path is resolved via:

```python
import importlib
from typing import Callable, Optional


def resolve_callable(path: Optional[str]) -> Optional[Callable]:
    if path is None:
        return None

    if ":" in path:
        # "package.module:factory_name"
        module_name, fn_name = path.split(":", 1)
    else:
        # "package.module.factory_name": the last dotted component is the attribute
        parts = path.split(".")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid hook callable path '{path}'. "
                "Expected 'module.submodule:factory' or 'module.submodule.factory'."
            )
        *mod_parts, fn_name = parts
        module_name = ".".join(mod_parts)

    module = importlib.import_module(module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_name}' has no attribute '{fn_name}' "
            f"(from hook path '{path}')"
        ) from e
```

**Failure modes**:

* If the path is malformed (it contains neither `:` nor `.`), a `ValueError` is raised at startup.
* If the module cannot be imported, the `ImportError` (typically `ModuleNotFoundError`) raised by `importlib.import_module` propagates.
* If the module imports but the attribute is missing, an `AttributeError` is raised with a clear error message.
* If the hook factory returns `None`, a warning is logged and no hook is registered for that spec (initialization continues).

The first three cause initialization to fail fast with a descriptive error; the last one is non-fatal.

#### `config` (optional)

* Arbitrary JSON object.
* Passed directly to the hook factory as a Python `dict` (an empty `dict` if the field is omitted or `null`).
* This lets you parameterize hook behavior from config (e.g. tags, log levels, sampling rates).
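As a sketch of config-driven behavior, the hypothetical factory below reads a `tag` and a `sample_every` key and records only every N-th call. Both keys, the factory name, and the `SAMPLED_CALLS` sink are illustrative assumptions for this example, not names SGLang defines. The hook reads `output` only through `getattr(..., "shape", None)` and unwraps a tuple's first element, since some modules return tuples rather than a single tensor; adjust that for the modules you actually target.

```python
from typing import Any, Callable

# Hypothetical sink for captured records; a real hook might write to a file or queue instead.
SAMPLED_CALLS: list[dict[str, Any]] = []


def sampling_hook_factory(config: dict[str, Any]) -> Callable:
    """Build a forward hook that records every `sample_every`-th call.

    `tag` and `sample_every` are illustrative config keys chosen for this sketch;
    SGLang passes the `config` dict through without interpreting it.
    """
    tag = config.get("tag", "default")
    sample_every = int(config.get("sample_every", 1))
    if sample_every < 1:
        raise ValueError(f"sample_every must be >= 1, got {sample_every}")
    calls = 0  # one counter per factory call, i.e. per hook spec

    def hook(module, inputs, output):
        nonlocal calls
        calls += 1
        if calls % sample_every != 0:
            return None  # None leaves the module's output unchanged
        # Some modules return a tuple; inspect its first element in that case.
        first = output[0] if isinstance(output, tuple) else output
        shape = getattr(first, "shape", None)
        SAMPLED_CALLS.append(
            {
                "tag": tag,
                "module_type": type(module).__name__,
                "call_index": calls,
                "shape": tuple(shape) if shape is not None else None,
            }
        )
        return None

    return hook
```

With `"config": {"tag": "mlp", "sample_every": 100}`, each hook built from that spec would record one call in a hundred. Because the factory runs once per spec, the counter is shared by every module the spec matches, and an invalid `sample_every` raises while hooks are being registered, so it surfaces at startup.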
---

### Hook lifecycle and behavior

Hooks are registered in `ModelRunner.initialize()`:

```python
if server_args.forward_hooks:
    register_forward_hooks(self.model, server_args.forward_hooks)
```

The actual registration logic is implemented by `register_forward_hooks`:

```python
import fnmatch
import logging
from typing import Any, List

from torch import nn

logger = logging.getLogger(__name__)


def register_forward_hooks(model: nn.Module, hook_specs: List[dict[str, Any]]) -> None:
    """
    hook_specs is a list of dicts from server_args.forward_hooks.
    Attaches forward hooks to the matching modules.
    """
    name_to_module = dict(model.named_modules())

    for spec in hook_specs:
        spec_name = spec.get("name", "")
        target_patterns = spec.get("target_modules", [])
        if not target_patterns:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'target_modules', skipping"
            )
            continue

        hook_factory_path = spec.get("hook_factory")
        if not hook_factory_path:
            logger.warning(
                f"Hook spec '{spec_name}' has no 'hook_factory', skipping"
            )
            continue

        config = spec.get("config") or {}
        hook_factory = resolve_callable(hook_factory_path)  # defined above

        # The factory is called once per spec; the resulting hook is shared by all matches.
        hook = hook_factory(config) if hook_factory else None
        if hook is None:
            logger.warning(
                f"Hook factory '{hook_factory_path}' for spec '{spec_name}' "
                "returned None, not registering any hook"
            )
            continue

        # Resolve patterns like "model.layers.*.mlp"
        matched = []
        for name, module in name_to_module.items():
            if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns):
                matched.append((name, module))

        if not matched:
            logger.warning(
                f"No modules matched hook spec '{spec_name}' "
                f"patterns={target_patterns}"
            )
            continue

        for module_name, module in matched:
            if hook:
                # The returned handle is discarded, so the hook cannot be removed later.
                _ = module.register_forward_hook(hook)
                logger.info(
                    f"Registered forward hook '{spec_name}' "
                    f"on {module_name}"
                )
```

Key points:

* Hooks are **forward hooks only** (via `module.register_forward_hook`).
* They are attached once at initialization.
* The factory is called once per spec, before pattern matching, so a single hook object (and any state it closes over) is shared by every module that spec matches.
* Hook handles are currently not stored on `ModelRunner`, so hooks cannot be removed later via this API.
* A spec with no `target_modules` or no `hook_factory` is skipped with a warning.
* Failure to match any modules is non-fatal; a warning is logged instead.
* If a hook factory returns `None`, a warning is logged and that spec is skipped.

---

### Writing a hook factory

A hook factory is a regular Python function that:

* Takes a `config: dict` (from JSON)
* Returns a forward hook function with signature `(module, inputs, output)`

The hook follows the usual PyTorch forward-hook contract: returning `None` leaves the module's output unchanged, while returning any other value replaces it.

Example:

```python
HOOK_CALLS = []


def dummy_hook_factory(config):
    """Factory that returns a forward hook capturing a tag from config."""
    tag = config.get("tag", "default")

    def hook(module, inputs, output):
        HOOK_CALLS.append(
            {
                "module_type": type(module).__name__,
                "tag": tag,
                # Assumes the module returns a single tensor; some modules return tuples.
                "shape": tuple(output.shape),
            }
        )
        # Returning the same `output` object (or None) keeps the result unchanged;
        # returning a different value would replace the module's output.
        return output

    return hook
```

In JSON:

```jsonc
{
  "forward_hooks": [
    {
      "name": "capture_outer",
      "target_modules": ["outer.0", "outer.1"],
      "hook_factory": "my_project.hooks:dummy_hook_factory",
      "config": { "tag": "outer" }
    }
  ]
}
```

This will:

* Resolve `my_project.hooks:dummy_hook_factory` to a Python callable.
* Call it with `config = {"tag": "outer"}`.
* Use the returned hook for all modules matching `outer.0` and `outer.1`.
* Append metadata about each call to `HOOK_CALLS`.

---

### Summary

* Define `forward_hooks` as a list of specs in `ServerArgs` to turn on the feature.
* Each spec:

  * selects modules via `target_modules` (glob patterns over `model.named_modules()`),
  * points to a hook factory via `hook_factory`,
  * passes arbitrary `config` into that factory.
* Hook factories are resolved via `resolve_callable`, which supports `module:factory` and `module.submodule.factory`.
* Hooks are standard PyTorch forward hooks, attached once at startup and invoked on every forward pass.
* Misconfiguration is either:

  * **fatal and explicit** (malformed path, unimportable module, or missing attribute), or
  * **non-fatal with clear warnings** (missing `target_modules` or `hook_factory`, no targets matched, or factory returned `None`).
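Because matching is plain `fnmatch` over the names yielded by `model.named_modules()`, you can preview which modules a spec will select before launching the server. The sketch below mirrors the matching loop in `register_forward_hooks`; the helper name `preview_matches` and the module names are hypothetical, and in practice you would pass it `[name for name, _ in model.named_modules()]`.

```python
import fnmatch
from typing import Iterable, List


def preview_matches(module_names: Iterable[str], target_patterns: List[str]) -> List[str]:
    """Return the module names a hook spec's `target_modules` would select.

    Mirrors the matching loop in `register_forward_hooks`: a name is selected
    if it matches any of the patterns under `fnmatch.fnmatch`.
    """
    return [
        name
        for name in module_names
        if any(fnmatch.fnmatch(name, pattern) for pattern in target_patterns)
    ]


# Hypothetical names, as `named_modules()` might yield them for a small model
# (the root module is reported with the empty name "").
names = ["", "outer", "outer.0", "outer.1", "outer.inner", "outer.inner.proj"]

print(preview_matches(names, ["outer.0", "outer.1"]))  # ['outer.0', 'outer.1']
# '*' also crosses dots, so deeper descendants are selected too:
print(preview_matches(names, ["outer.*"]))  # ['outer.0', 'outer.1', 'outer.inner', 'outer.inner.proj']
# A character class restricts the match to single-digit direct children:
print(preview_matches(names, ["outer.[0-9]"]))  # ['outer.0', 'outer.1']
```

An empty result here corresponds to the "No modules matched" warning at startup.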