|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import torch |
|
|
from .functions import REG_FUNCTION_MAP |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HookMonitor:
    """
    Monitor per-layer forward activations and backward gradients of a
    PyTorch model by registering hooks on every submodule.

    A forward hook is installed on each submodule returned by
    ``model.named_modules()``. On every forward pass the hook:

    * detaches the output tensor and casts it to float,
    * computes each metric from ``REG_FUNCTION_MAP`` whose flag is enabled
      in ``watcher`` (keys without a ``grad_`` prefix) and accumulates it
      together with a ``<metric>/valid/`` counter,
    * if any ``grad_<metric>`` flag is enabled, registers a tensor hook on
      the output so the same metrics are computed on the gradient during
      backpropagation and accumulated under ``grad_<metric>`` keys.

    ``get_stats()`` returns the accumulated sums normalized by their
    ``/valid/`` counters (i.e. a running mean per metric and layer).

    The monitor only reads activations and gradients and does not interfere
    with training; forward-side computation runs under ``torch.no_grad()``.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose submodules are monitored.
    watcher : dict
        Maps metric names to boolean flags. Keys must match
        ``REG_FUNCTION_MAP`` names, optionally prefixed with ``grad_`` for
        gradient-side metrics, e.g.::

            {"mean": True, "std": True, "grad_mean": True}

        Metrics not enabled here are never computed.
    logger : logging.Logger
        Logger used for debug messages and error reporting.

    Attributes
    ----------
    stats : dict
        ``{layer_name: {metric: accumulated_sum, metric + '/valid/': count}}``.
    handles : list
        Hook handles kept so ``remove()`` can detach all hooks safely.

    Usage
    -----
    >>> monitor = HookMonitor(model, watcher, logger).attach()
    >>> output = model(x)
    >>> loss.backward()
    >>> stats = monitor.get_stats()
    >>> monitor.remove()

    Or as a context manager (hooks removed automatically)::

        with HookMonitor(model, watcher, logger) as monitor:
            output = model(x)
            loss.backward()
        stats = monitor.get_stats()

    Notes
    -----
    - The gradient hook is attached to the activation tensor (module
      output), not to model parameters.
    - A metric missing its ``/valid/`` counter triggers an error log and is
      returned unnormalized by ``get_stats()``.
    """

    def __init__(self, model: torch.nn.Module, watcher: dict, logger: logging.Logger):
        """
        Store the monitoring configuration.

        No hooks are installed here; they are registered only by
        ``attach()`` or on entering a ``with`` block.

        :param model: Model whose submodules will be monitored.
        :param watcher: Boolean flags selecting which metrics to compute
            (see class docstring).
        :param logger: Logger for debug/error reporting.
        """
        self.logger: logging.Logger = logger
        self.model: torch.nn.Module = model
        self.watcher: dict = watcher
        # Accumulated per-layer statistics; layout described in the class docstring.
        self.stats: dict = {}
        # Forward-hook handles, kept so remove() can detach them later.
        self.handles: list = []

    def _build_hook(self, name):
        """
        Build the forward hook for the submodule called ``name``.

        The returned hook accumulates activation metrics for tensor outputs
        and, when at least one ``grad_*`` metric is enabled and the output
        requires grad, attaches a tensor hook that accumulates gradient
        metrics during backward.

        :param name: Qualified module name used as the key into ``self.stats``.
        :return: A callable suitable for ``register_forward_hook``.
        """

        @torch.no_grad()
        def hook(module, inputs, output):
            # Non-tensor outputs (tuples, dicts, None, ...) are skipped.
            if not torch.is_tensor(output):
                return

            act_detached = output.detach().float()
            layer_stats = self.stats.setdefault(name, {})

            for metric, compute in REG_FUNCTION_MAP.items():
                if metric.startswith('grad_') or not self.watcher.get(metric, False):
                    continue
                value = compute(act_detached, ...)
                # A None result means "no statistic for this input"; it must
                # not advance the /valid/ counter or it would skew the mean.
                if value is not None:
                    layer_stats[metric] = layer_stats.get(metric, 0.0) + value
                    layer_stats[metric + '/valid/'] = layer_stats.get(metric + '/valid/', 0.0) + 1

            # Gradient-side metrics requested via a 'grad_' watcher flag.
            grad_metrics = [
                (metric, compute) for metric, compute in REG_FUNCTION_MAP.items()
                if not metric.startswith('grad_') and self.watcher.get('grad_' + metric, False)
            ]

            def grad_hook(grad):
                gd = grad.detach().float()
                for metric, compute in grad_metrics:
                    value = compute(gd, ...)
                    if value is not None:
                        key = 'grad_' + metric
                        layer_stats[key] = layer_stats.get(key, 0.0) + value
                        layer_stats[key + '/valid/'] = layer_stats.get(key + '/valid/', 0.0) + 1

            # Only attach the gradient hook when it actually has work to do;
            # this avoids a per-forward closure/registration cost for runs
            # that monitor activations only.
            if grad_metrics and output.requires_grad:
                output.register_hook(grad_hook)

        return hook

    def get_stats(self) -> dict:
        """
        Return normalized statistics per layer.

        Each accumulated metric is divided by its ``<metric>/valid/``
        counter. A metric missing its counter is returned unnormalized and
        an error is logged. The internal ``/valid/`` counters themselves are
        never included in the result.

        :return: ``{layer_name: {metric: normalized_value}}``.
        """
        stats = {}
        for layer_name, layer_stats in self.stats.items():
            sub_stats = {}
            for key, item in layer_stats.items():
                if '/valid/' in key:
                    continue  # internal bookkeeping, not a user-facing metric
                valid_key = key + '/valid/'
                if valid_key in layer_stats:
                    sub_stats[key] = item / layer_stats[valid_key]
                else:
                    self.logger.error(f"Key {key} has no valid count, skipping normalization.")
                    sub_stats[key] = item
            stats[layer_name] = sub_stats
        return stats

    def attach(self):
        """
        Register a forward hook on every submodule of the model.

        :return: ``self``, enabling ``HookMonitor(...).attach()`` chaining.
        """
        if self.handles:
            # Attaching twice would double-count every metric.
            self.logger.warning("[Hooks] attach() called while hooks are already registered.")
        for name, module in self.model.named_modules():
            self.handles.append(module.register_forward_hook(self._build_hook(name)))
        return self

    def clear(self):
        """
        Reset all accumulated statistics (hooks stay attached).

        :return: Nothing.
        """
        self.stats.clear()

    def remove(self):
        """
        Detach every registered hook from the model.

        :return: Nothing.
        """
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def __enter__(self):
        # Context-manager entry: install the hooks and expose the monitor.
        self.logger.debug("[Hooks] Attaching HookMonitor...")
        return self.attach()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Context-manager exit: always clean up; exceptions propagate
        # (implicit None return).
        self.logger.debug("[Hooks] Removing HookMonitor...")
        self.remove()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|