# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
#                          START OF FILE                            #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
import logging

import torch

from .functions import REG_FUNCTION_MAP


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
#                                                                   #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
class HookMonitor:
    """
    Monitor forward activations and backward gradients of a PyTorch model.

    A forward hook is registered on every submodule returned by
    ``model.named_modules()``.  On each forward pass the hook:

    - detaches the module output and casts it to ``float``;
    - for every metric in ``REG_FUNCTION_MAP`` whose flag is enabled in
      ``watcher`` (and whose name does not start with ``'grad_'``),
      computes the metric on the activation and accumulates it;
    - if the output requires grad, registers a tensor hook so the same
      metrics prefixed with ``'grad_'`` (e.g. ``'grad_mean'``) are
      accumulated on the gradient during backpropagation.

    For every accumulated metric a companion ``'<metric>/valid/'`` counter
    is incremented; ``get_stats()`` returns ``sum / count`` per metric per
    layer.  Forward-side computation runs under ``torch.no_grad()`` and the
    monitor never modifies activations or gradients.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose submodules are monitored.
    watcher : dict
        Maps metric names (keys of ``REG_FUNCTION_MAP``, optionally
        prefixed with ``'grad_'``) to boolean enable flags, e.g.
        ``{"mean": True, "std": True, "grad_mean": True}``.
        Disabled/absent metrics are never computed.
    logger : logging.Logger
        Logger used for error and debug reporting.

    Attributes
    ----------
    stats : dict
        ``{layer_name: {metric: accumulated_value, metric + '/valid/': count}}``.
    handles : list
        Hook handles from ``register_forward_hook``, kept so ``remove()``
        can detach all hooks safely.

    Usage
    -----
    >>> monitor = HookMonitor(model, watcher, logger)
    >>> monitor.attach()
    >>> output = model(x)
    >>> loss.backward()
    >>> stats = monitor.get_stats()
    >>> monitor.remove()

    or, as a context manager:

    >>> with HookMonitor(model, watcher, logger) as monitor:
    ...     output = model(x)
    ...     loss.backward()
    >>> stats = monitor.get_stats()

    Notes
    -----
    - Gradient hooks attach to activation tensors (module outputs), not to
      model parameters.
    - A missing ``'/valid/'`` counter triggers an error log and the raw
      (unnormalized) value is returned for that metric.
    """

    def __init__(self, model: torch.nn.Module, watcher: dict, logger: logging.Logger):
        """
        Store the monitoring configuration; no hooks are installed yet.

        Hooks are registered only when ``attach()`` is called or when the
        instance is entered as a context manager.

        Parameters
        ----------
        model : torch.nn.Module
            Model whose submodules will receive forward hooks on attach.
        watcher : dict
            Boolean flags selecting which ``REG_FUNCTION_MAP`` metrics
            (and their ``'grad_'``-prefixed variants) to compute.
        logger : logging.Logger
            Logger for errors, warnings and debug messages.
        """
        self.logger: logging.Logger = logger
        self.model: torch.nn.Module = model
        self.watcher: dict = watcher
        # Accumulated per-layer sums and '/valid/' counters.
        self.stats: dict = dict()
        # Forward-hook handles, removed by `remove()`.
        self.handles: list = list()

    @staticmethod
    def _accumulate(layer_stats: dict, metric_name: str, value) -> None:
        """
        Add `value` to the running sum for `metric_name` and bump its
        '/valid/' counter.  A `None` value is silently skipped, so metrics
        may opt out of a given tensor without polluting the counters.
        """
        if value is not None:
            layer_stats[metric_name] = layer_stats.get(metric_name, 0.0) + value
            counter_key = metric_name + '/valid/'
            layer_stats[counter_key] = layer_stats.get(counter_key, 0.0) + 1

    def _build_hook(self, name):
        """
        Build a forward hook bound to layer `name`.

        The returned hook accumulates forward metrics on the module output
        and, when the output requires grad, registers a gradient hook that
        accumulates the 'grad_'-prefixed metrics during backprop.
        """

        @torch.no_grad()
        def hook(module, inputs, output):
            # Only tensor outputs are monitored (tuples/dicts are ignored).
            if not torch.is_tensor(output):
                return
            activation = output.detach().float()
            layer_stats = self.stats.setdefault(name, {})

            # Forward-pass metrics (non-'grad_' entries only).
            for metric_name, compute_fn in REG_FUNCTION_MAP.items():
                if metric_name.startswith('grad_'):
                    continue
                if self.watcher.get(metric_name, False):
                    # NOTE(review): `...` (Ellipsis) is forwarded verbatim as the
                    # second argument — assumed part of the REG_FUNCTION_MAP call
                    # contract; confirm against the functions module.
                    self._accumulate(layer_stats, metric_name, compute_fn(activation, ...))

            def grad_hook(grad):
                # Backward-pass metrics: 'grad_<metric>' flags reuse the same
                # compute function, applied to the gradient tensor.
                gradient = grad.detach().float()
                for metric_name, compute_fn in REG_FUNCTION_MAP.items():
                    if metric_name.startswith('grad_'):
                        continue
                    grad_metric_name = 'grad_' + metric_name
                    if self.watcher.get(grad_metric_name, False):
                        self._accumulate(layer_stats, grad_metric_name, compute_fn(gradient, ...))

            if output.requires_grad:
                output.register_hook(grad_hook)

        return hook

    def get_stats(self) -> dict:
        """
        Return normalized statistics (accumulated sum / valid count) per
        layer and metric.

        :return: ``{layer_name: {metric: normalized_value}}``.  Metrics
                 lacking a '/valid/' counter are returned unnormalized and
                 an error is logged.
        """
        stats = dict()
        for layer_name, layer_stats in self.stats.items():
            sub_stats = dict()
            for key, item in layer_stats.items():
                if '/valid/' in key:
                    continue  # counters are internal bookkeeping
                counter_key = key + '/valid/'
                if counter_key in layer_stats:
                    sub_stats[key] = item / layer_stats[counter_key]
                else:
                    self.logger.error(f"Key {key} has no valid count, skipping normalization.")
                    sub_stats[key] = item
            stats[layer_name] = sub_stats
        return stats

    def attach(self):
        """
        Register a forward hook on every submodule of the model.

        :return: self, to allow fluent use (``HookMonitor(...).attach()``).
        """
        for name, module in self.model.named_modules():
            handle = module.register_forward_hook(self._build_hook(name))
            self.handles.append(handle)
        return self

    def clear(self):
        """
        Reset all accumulated statistics (hooks stay attached).

        :return: Nothing.
        """
        self.stats.clear()

    def remove(self):
        """
        Detach every registered hook from the model.

        :return: Nothing.
        """
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def __enter__(self):
        # Context-manager entry: attach hooks and expose the monitor.
        self.logger.debug("[Hooks] Attaching HookMonitor...")
        return self.attach()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Context-manager exit: always detach hooks; exceptions propagate.
        self.logger.debug("[Hooks] Removing HookMonitor...")
        self.remove()


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
#                           END OF FILE                             #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #