File size: 9,977 Bytes
dbd79bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 |
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# START OF FILE #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
import logging
import torch
from .functions import REG_FUNCTION_MAP
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
class HookMonitor:
    """Collect per-layer activation and gradient statistics via hooks.

    A forward hook is installed on every submodule returned by
    ``model.named_modules()``. Each hook detaches the layer output, casts
    it to float, and feeds it to every metric in ``REG_FUNCTION_MAP``
    whose flag is enabled in ``watcher``. A tensor-level hook is also
    attached to the output so metrics enabled under a ``grad_<name>`` key
    are accumulated from the gradient during backpropagation.

    Metrics are stored as running sums next to a ``<metric>/valid/`` hit
    counter; :meth:`get_stats` divides each sum by its counter to yield
    per-layer averages. Monitoring is read-only and does not interfere
    with training (activation hooks run under ``torch.no_grad()`` and all
    tensors are detached before use).

    Parameters
    ----------
    model : torch.nn.Module
        Model whose submodules will be hooked.
    watcher : dict
        Maps metric names (keys of ``REG_FUNCTION_MAP``, optionally
        prefixed with ``grad_`` for gradient-side metrics) to booleans.
        Disabled or absent metrics are never computed. Example::

            {"mean": True, "std": True, "grad_mean": True}

    logger : logging.Logger
        Destination for debug and error messages.

    Attributes
    ----------
    stats : dict
        Nested per-layer accumulators; normalized by :meth:`get_stats`.
    handles : list
        Hook handles kept so :meth:`remove` can detach everything.

    Examples
    --------
    Manual lifecycle:

    >>> monitor = HookMonitor(model, watcher, logger)
    >>> monitor.attach()
    >>> loss_fn(model(x), y).backward()
    >>> stats = monitor.get_stats()
    >>> monitor.remove()

    Or as a context manager:

    >>> with HookMonitor(model, watcher, logger) as monitor:
    ...     loss_fn(model(x), y).backward()
    >>> stats = monitor.get_stats()

    Notes
    -----
    - Gradient hooks are attached to activation tensors (module outputs),
      not to model parameters.
    - A metric missing its ``/valid/`` counter triggers an error log and
      is returned un-normalized by :meth:`get_stats`.
    """

    def __init__(self, model: torch.nn.Module, watcher: dict, logger: logging.Logger):
        """Store the monitoring configuration; no hooks are installed yet.

        Hooks are registered only by :meth:`attach` or on entering a
        ``with`` block, and detached again by :meth:`remove` or on exit.

        Parameters
        ----------
        model : torch.nn.Module
            Model whose submodules will be monitored.
        watcher : dict
            Boolean flags selecting which ``REG_FUNCTION_MAP`` metrics
            (and ``grad_``-prefixed gradient metrics) to compute.
        logger : logging.Logger
            Logger used for debug and error reporting.
        """
        self.logger: logging.Logger = logger
        self.model: torch.nn.Module = model
        self.watcher: dict = watcher
        # Per-layer accumulators, filled lazily by the hooks.
        self.stats: dict = {}
        # Handles from register_forward_hook, kept for later removal.
        self.handles: list = []

    def _build_hook(self, layer_name):
        """Create the forward hook closure for the layer ``layer_name``."""

        @torch.no_grad()
        def forward_hook(module, inputs, output):
            # Only plain tensor outputs are monitored; anything else
            # (tuples, dicts, None) is skipped silently.
            if not torch.is_tensor(output):
                return
            activation = output.detach().float()
            layer_stats = self.stats.setdefault(layer_name, {})

            # Activation-side metrics (non-"grad_" entries only).
            for metric, compute in REG_FUNCTION_MAP.items():
                if metric.startswith('grad_') or not self.watcher.get(metric, False):
                    continue
                result = compute(activation, ...)
                if result is not None:
                    layer_stats[metric] = layer_stats.get(metric, 0.0) + result
                    layer_stats[metric + '/valid/'] = layer_stats.get(metric + '/valid/', 0.0) + 1

            def grad_hook(grad):
                # Gradient-side metrics, enabled via the "grad_" prefix
                # in the watcher but computed by the same base function.
                grad_tensor = grad.detach().float()
                for metric, compute in REG_FUNCTION_MAP.items():
                    key = 'grad_' + metric
                    if metric.startswith('grad_') or not self.watcher.get(key, False):
                        continue
                    result = compute(grad_tensor, ...)
                    if result is not None:
                        layer_stats[key] = layer_stats.get(key, 0.0) + result
                        layer_stats[key + '/valid/'] = layer_stats.get(key + '/valid/', 0.0) + 1

            # Hook the activation tensor itself so backprop feeds us its
            # gradient; only possible when it participates in autograd.
            if output.requires_grad:
                output.register_hook(grad_hook)

        return forward_hook

    def get_stats(self) -> dict:
        """
        Get the statistics of the hooks.

        Each accumulated sum is divided by its ``/valid/`` counter; a
        metric without a counter is logged as an error and returned raw.

        :return: A dictionary with the statistics.
        """
        normalized = {}
        for layer_name, layer_stats in self.stats.items():
            layer_result = {}
            for key, total in layer_stats.items():
                # Counter entries are bookkeeping, not results.
                if '/valid/' in key:
                    continue
                counter_key = key + '/valid/'
                if counter_key in layer_stats:
                    layer_result[key] = total / layer_stats[counter_key]
                else:
                    self.logger.error(f"Key {key} has no valid count, skipping normalization.")
                    layer_result[key] = total
            normalized[layer_name] = layer_result
        return normalized

    def attach(self):
        """
        Registers all the hooks in the model.

        :return: The object.
        """
        self.handles.extend(
            module.register_forward_hook(self._build_hook(module_name))
            for module_name, module in self.model.named_modules()
        )
        return self

    def clear(self):
        """
        Clear stats' dictionary.

        :return: Nothing
        """
        self.stats.clear()

    def remove(self):
        """
        Remove all the hooks from the model.

        :return: Nothing.
        """
        while self.handles:
            self.handles.pop().remove()

    def __enter__(self):
        """Attach all hooks and return the monitor."""
        self.logger.debug("[Hooks] Attaching HookMonitor...")
        return self.attach()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Detach all hooks; exceptions propagate unchanged."""
        self.logger.debug("[Hooks] Removing HookMonitor...")
        self.remove()
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
# END OF FILE #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|