# Provenance (from repository history): commit dbd79bd by alverciito —
# "upload safetensors and refactor research files"
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# START OF FILE #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
import logging
import torch
from .functions import REG_FUNCTION_MAP
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
class HookMonitor:
    """
    Monitor forward activations and backward gradients of a PyTorch model.

    A forward hook is registered on every submodule returned by
    ``model.named_modules()``. On each forward pass the hook:

    * detaches the module output and casts it to ``float``;
    * for every metric in ``REG_FUNCTION_MAP`` whose watcher flag is enabled
      (and whose name is not ``grad_``-prefixed), computes the metric on the
      activation and accumulates it;
    * registers a gradient hook on the activation tensor so that metrics
      enabled under ``grad_<metric>`` watcher keys are accumulated during
      backpropagation.

    For each accumulated metric a companion ``<metric>/valid/`` counter is
    kept; ``get_stats()`` returns ``sum / valid_count`` per metric per layer.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose submodules receive forward hooks.
    watcher : dict
        Maps metric names (keys of ``REG_FUNCTION_MAP``, optionally prefixed
        with ``grad_`` for gradient statistics) to boolean enable flags, e.g.
        ``{"mean": True, "std": True, "grad_mean": True}``. Metrics not
        enabled here are never computed.
    logger : logging.Logger
        Logger used for debug, warning and error reporting.

    Attributes
    ----------
    stats : dict
        Per-layer accumulators: ``{layer_name: {metric: value, ...}}``.
        Normalized results are produced by ``get_stats()``.
    handles : list
        Hook handles returned by ``register_forward_hook``, stored so that
        ``remove()`` can detach every hook safely.

    Usage
    -----
    >>> monitor = HookMonitor(model, watcher, logger)
    >>> monitor.attach()
    >>> output = model(x)
    >>> loss.backward()
    >>> stats = monitor.get_stats()
    >>> monitor.remove()

    Or as a context manager:

    >>> with HookMonitor(model, watcher, logger) as monitor:
    ...     output = model(x)
    ...     loss.backward()
    >>> stats = monitor.get_stats()

    Notes
    -----
    - Gradient hooks are attached to activation tensors (module outputs),
      not to model parameters.
    - Forward hooks run under ``torch.no_grad()`` and only read values, so
      the training process is not affected.
    - Non-tensor module outputs (tuples, dicts, ...) are skipped entirely.
    - A metric missing its ``/valid/`` counter triggers an error log and is
      returned un-normalized by ``get_stats()``.
    """

    def __init__(self, model: torch.nn.Module, watcher: dict, logger: logging.Logger):
        """
        Store the monitoring configuration.

        No hooks are installed at construction time; the monitor becomes
        active only after calling ``attach()`` or entering a ``with`` block.

        Parameters
        ----------
        model : torch.nn.Module
            Model whose submodules will be monitored.
        watcher : dict
            Metric-name -> bool enable flags (see class docstring).
        logger : logging.Logger
            Logger used for reporting.
        """
        self.logger: logging.Logger = logger
        self.model: torch.nn.Module = model
        self.watcher: dict = watcher
        self.stats: dict = dict()       # per-layer metric accumulators
        self.handles: list = list()     # forward-hook handles for cleanup

    def _build_hook(self, name):
        """
        Build a forward hook for the submodule called ``name``.

        The returned hook accumulates activation metrics and, when the
        output tensor participates in autograd, registers a gradient hook
        that accumulates ``grad_``-prefixed metrics during backprop.
        """

        @torch.no_grad()
        def hook(*args):
            # Forward hooks receive (module, input, output).
            _, _, act = args
            if not torch.is_tensor(act):
                # Modules may return tuples/dicts; only plain tensors are
                # monitored. Returning here also avoids touching
                # `.requires_grad` on a non-tensor, which would raise.
                return
            act_detached = act.detach().float()
            layer_stats = self.stats.setdefault(name, {})
            # Activation metrics (non-grad watcher flags):
            for fn_name, compute_fn in REG_FUNCTION_MAP.items():
                if fn_name.startswith('grad_') or not self.watcher.get(fn_name, False):
                    continue
                value = compute_fn(act_detached, ...)
                if value is not None:
                    layer_stats[fn_name] = layer_stats.get(fn_name, 0.0) + value
                    layer_stats[fn_name + '/valid/'] = layer_stats.get(fn_name + '/valid/', 0.0) + 1

            def grad_hook(grad):
                # Runs during backprop with the gradient w.r.t. `act`.
                gd = grad.detach().float()
                for fn_name, compute_fn in REG_FUNCTION_MAP.items():
                    # Use a separate key instead of rebinding the loop
                    # variable (error-prone in the original version).
                    grad_key = 'grad_' + fn_name
                    if fn_name.startswith('grad_') or not self.watcher.get(grad_key, False):
                        continue
                    value = compute_fn(gd, ...)
                    if value is not None:
                        layer_stats[grad_key] = layer_stats.get(grad_key, 0.0) + value
                        layer_stats[grad_key + '/valid/'] = layer_stats.get(grad_key + '/valid/', 0.0) + 1

            # Only tensors that participate in autograd accept grad hooks.
            if act.requires_grad:
                act.register_hook(grad_hook)

        return hook

    def get_stats(self) -> dict:
        """
        Return normalized statistics per layer.

        Each accumulated metric is divided by its ``<metric>/valid/``
        counter. A metric whose counter is missing is returned
        un-normalized and an error is logged.

        :return: ``{layer_name: {metric: normalized_value}}``.
        """
        stats = dict()
        for layer_name, layer_stats in self.stats.items():
            sub_stats = dict()
            for key, item in layer_stats.items():
                if '/valid/' in key:
                    continue  # counters are internal, never reported directly
                valid_key = key + '/valid/'
                if valid_key in layer_stats:
                    sub_stats[key] = item / layer_stats[valid_key]
                else:
                    self.logger.error(f"Key {key} has no valid count, skipping normalization.")
                    sub_stats[key] = item
            stats[layer_name] = sub_stats
        return stats

    def attach(self):
        """
        Register a forward hook on every submodule of the model.

        Calling ``attach()`` while hooks are already installed would
        double-register them and double-count every metric, so repeated
        calls are ignored with a warning.

        :return: The object (for chaining / context-manager use).
        """
        if self.handles:
            self.logger.warning("[Hooks] attach() called while hooks are already registered; ignoring.")
            return self
        for name, module in self.model.named_modules():
            self.handles.append(module.register_forward_hook(self._build_hook(name)))
        return self

    def clear(self):
        """
        Clear the accumulated statistics dictionary.

        :return: Nothing.
        """
        self.stats.clear()

    def remove(self):
        """
        Remove all registered hooks from the model.

        :return: Nothing.
        """
        for h in self.handles:
            h.remove()
        self.handles.clear()

    def __enter__(self):
        """Attach hooks on entering a ``with`` block."""
        self.logger.debug("[Hooks] Attaching HookMonitor...")
        return self.attach()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove hooks on leaving the ``with`` block (exceptions propagate)."""
        self.logger.debug("[Hooks] Removing HookMonitor...")
        self.remove()
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# END OF FILE #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #