| | """Self-contained subset of :mod:`circuit_sparsity.hook_utils` for inference builds. |
| | |
| | The full module has no exotic dependencies, but mirroring the definitions here |
| | keeps the trimmed :mod:`circuit_sparsity.inference.gpt` module hermetic and easy to vendor. The |
| | implementations below are copied with minor tweaks for readability so that code |
| | written against :func:`hook_recorder`, :func:`hook_namespace`, and |
| | :func:`torch_recompute_preserving_hook_context` behaves identically in both the |
| | training and inference configurations. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import re |
| | from contextlib import contextmanager |
| | from functools import partial |
| |
|
| | import torch |
| | import torch.utils.checkpoint |
| |
|
| |
|
class HookContext:
    """State container used by the hook helpers.

    Attributes:
        curcontext: Dict that receives recorded tensors while a recorder is
            active, otherwise ``None``.
        curname: Dotted namespace prefix prepended to every hook name.
        curregex: Compiled pattern selecting which keys get recorded.
        curinterventions: Mapping from key to a callable applied to the
            tensor saved under that key, or ``None``.
        save_grads: Whether gradients of saved tensors are recorded as well.
        curintervtransformer: Callable applied to the interventions mapping
            on every lookup. It is *not* cleared by :meth:`_reset`.
    """

    def __init__(self) -> None:
        self._reset()
        # Identity by default; ``hook_intervention_transform`` stacks on top.
        self.curintervtransformer = lambda x: x

    def _reset(self) -> None:
        """Clear all recording state (the intervention transformer is kept)."""
        self.curcontext = None
        self.curname = ""
        self.curregex = None
        self.curinterventions = None
        self.save_grads = None

    def _get_interventions(self):
        """Return the interventions mapping after applying the transformer.

        An empty dict is passed to the transformer when no interventions
        are set.
        """
        return self.curintervtransformer(
            self.curinterventions if self.curinterventions is not None else {}
        )

    @contextmanager
    def hook_recorder(self, regex: str = ".*", interventions=None, save_grads: bool = False):
        """Record tensors that pass through hooks matching ``regex``.

        Args:
            regex: Pattern applied with ``re.match`` (anchored at the start
                only) to each fully qualified hook key.
            interventions: Optional mapping from key to a callable that
                replaces the tensor saved under that key.
            save_grads: When true, gradients are recorded under
                ``<key>.grad`` for keys where that name matches ``regex``.

        Yields:
            The dict that :meth:`hook_save` fills in.

        Raises:
            AssertionError: If a recorder is already active on this context.
        """
        assert self.curcontext is None, "reentrancy not allowed!"

        try:
            self.curcontext = {}
            self.curregex = re.compile(regex)
            self.curname = ""
            self.curinterventions = interventions
            self.save_grads = save_grads

            yield self.curcontext
        finally:
            self._reset()
            # The module-level active context is reset as well; it may be a
            # different object than ``self``.
            get_context()._reset()

    @contextmanager
    def hook_intervention_transform(self, intervention_transformer):
        """Temporarily compose ``intervention_transformer`` onto the current one.

        Inside the ``with`` body, lookups see
        ``intervention_transformer(previous_transformer(interventions))``.
        The previous transformer is restored on exit.
        """
        oldintervention_transformer = self.curintervtransformer

        def compose(f, g):
            return lambda x: f(g(x))

        self.curintervtransformer = compose(
            intervention_transformer,
            self.curintervtransformer,
        )

        try:
            yield
        finally:
            self.curintervtransformer = oldintervention_transformer

    @contextmanager
    def hook_namespace(self, name: str):
        """Temporarily push ``name`` onto the hook namespace stack."""
        oldname = self.curname
        self.curname = self.curname + name + "."

        try:
            yield
        finally:
            self.curname = oldname

    def hook_save(self, name: str, tensor: torch.Tensor) -> torch.Tensor:
        """Optionally intervene on and record ``tensor`` under the current namespace.

        Args:
            name: Hook name; the fully qualified key is ``curname + name``.
            tensor: Tensor passing through the hook.

        Returns:
            The tensor to use downstream: the intervention's output if one is
            registered for the key, wrapped in a gradient-recording identity
            op when gradient saving applies, otherwise the input tensor.
        """
        # Resolve the key once, at forward time, so every use below agrees.
        key = self.curname + name

        interventions = self._get_interventions()
        if interventions is not None and key in interventions:
            tensor = interventions[key](tensor)

        if self.curcontext is None:
            return tensor

        if self.curregex.match(key):
            self.curcontext[key] = tensor

        if self.save_grads and tensor.requires_grad:
            grad_key = key + ".grad"

            if self.curregex.match(grad_key):

                class _Grad(torch.autograd.Function):
                    """Identity op whose backward stores the incoming gradient."""

                    @staticmethod
                    def forward(ctx, input_tensor):
                        return input_tensor

                    @staticmethod
                    def backward(ctx, grad_output):
                        # ``grad_key`` was fixed at forward time: backward
                        # usually runs after the namespace has been popped,
                        # so reading ``self.curname`` here would drop the
                        # prefix and disagree with the regex check above.
                        self.curcontext[grad_key] = grad_output
                        return grad_output

                tensor = _Grad.apply(tensor)

        return tensor
| |
|
| |
|
def set_context(new_context: HookContext) -> None:
    """Install ``new_context`` as the module-wide active hook context."""
    global context
    context = new_context
| |
|
| |
|
def get_context() -> HookContext:
    """Return the module-wide active hook context."""
    # Reading a module global needs no ``global`` declaration.
    return context
| |
|
| |
|
def torch_recompute_preserving_hook_context(f, *xs, use_reentrant=None):
    """Wrapper around :func:`torch.utils.checkpoint.checkpoint` that propagates hooks.

    ``f`` is run under a private :class:`HookContext` initialised from the
    currently active one, so hook calls made inside ``f`` see the same
    namespace, regex, interventions, intervention transformer and
    ``save_grads`` flag. Tensors recorded during the first invocation of
    ``f`` are merged back into the caller's record dict; later invocations
    do not touch it.

    Args:
        f: Callable to checkpoint.
        *xs: Positional arguments forwarded to ``f``.
        use_reentrant: Forwarded unchanged to
            :func:`torch.utils.checkpoint.checkpoint`.

    Returns:
        Whatever ``torch.utils.checkpoint.checkpoint`` returns for ``f(*xs)``.
    """
    oldcontext = get_context()

    # Build a private context mirroring the active one. The record dict and
    # interventions are shallow-copied so writes made while ``f`` runs do not
    # land directly in the caller's dicts.
    curcontext = HookContext()
    curcontext.curcontext = (
        dict(oldcontext.curcontext) if oldcontext.curcontext is not None else None
    )
    curcontext.curregex = oldcontext.curregex
    curcontext.curname = oldcontext.curname
    curcontext.curinterventions = (
        dict(oldcontext.curinterventions) if oldcontext.curinterventions is not None else None
    )
    # Carry over transforms installed via ``hook_intervention_transform``;
    # a fresh HookContext would otherwise fall back to the identity transform.
    curcontext.curintervtransformer = oldcontext.curintervtransformer
    curcontext.save_grads = oldcontext.save_grads

    is_recompute = False

    def _f(curcontext: HookContext, *xs):
        nonlocal is_recompute
        initcontext = get_context()

        set_context(curcontext)
        try:
            res = f(*xs)

            # Only the first call publishes recordings to the caller.
            if not is_recompute and oldcontext.curcontext is not None:
                oldcontext.curcontext |= curcontext.curcontext
        finally:
            # Always restore whichever context was active on entry, and mark
            # every subsequent call as a recompute.
            set_context(initcontext)
            is_recompute = True
        return res

    return torch.utils.checkpoint.checkpoint(
        partial(_f, curcontext), *xs, use_reentrant=use_reentrant
    )
| |
|
| |
|
# Module-level active hook context: returned by get_context() and replaced
# by set_context().
context = HookContext()
| |
|
| |
|
def hook_recorder(*args, **kwargs):
    """Forward to ``hook_recorder`` on the currently active hook context."""
    active = get_context()
    return active.hook_recorder(*args, **kwargs)
| |
|
| |
|
def hook_namespace(*args, **kwargs):
    """Forward to ``hook_namespace`` on the currently active hook context."""
    active = get_context()
    return active.hook_namespace(*args, **kwargs)
| |
|
| |
|
def hook_save(*args, **kwargs):
    """Forward to ``hook_save`` on the currently active hook context."""
    active = get_context()
    return active.hook_save(*args, **kwargs)
| |
|
| |
|
def hook_intervention_transform(*args, **kwargs):
    """Forward to ``hook_intervention_transform`` on the currently active hook context."""
    active = get_context()
    return active.hook_intervention_transform(*args, **kwargs)