# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
#                         START OF FILE                             #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
import logging
import torch
from .functions import REG_FUNCTION_MAP


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
#                                                                   #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
class HookMonitor:
    """
    Monitors forward activations and backward gradients of a PyTorch model by
    registering hooks on all its submodules. The monitor computes per-layer
    statistics defined in `REG_FUNCTION_MAP`, accumulating them during forward
    and backward passes, and provides normalized results at the end.

    This class is designed to be lightweight, safe (uses no_grad for activation
    hooks), and usable as a context manager to automate attachment and cleanup
    of hooks.

    ----------------------------------------
    Core Behavior
    ----------------------------------------
    - During the forward pass:
        • A forward hook receives (module, input, output).
        • The activation tensor is detached and cast to float.
        • For each registered metric in REG_FUNCTION_MAP, if its watcher flag
          is enabled, the metric is computed and accumulated.
        • A gradient hook is registered on the output tensor so that gradient
          statistics can also be collected during backpropagation.

    - During backpropagation:
        • The gradient hook receives the gradient tensor for the activation.
        • Any metric marked as `grad_<metric>` in the watcher dictionary will be
          applied to the gradient tensor and accumulated.

    - Statistics:
        • For each metric, the class tracks both the accumulated value and a
          "/valid/" counter.
        • `get_stats()` returns normalized statistics (sum / valid_count) for
          each metric per layer.

    ----------------------------------------
    Parameters
    ----------------------------------------
    model : torch.nn.Module
        The model whose modules will be monitored. All submodules returned by
        `model.named_modules()` will receive a forward hook.

    watcher : dict
        A dictionary mapping metric names to boolean flags. Keys must match the
        names used in `REG_FUNCTION_MAP`. Example:
            {
                "mean": True,
                "std": True,
                "grad_mean": True
            }

        Metrics not enabled here will not be computed.

    logger : logging.Logger
        A Logger used to report errors, debugging information, and warnings.

    ----------------------------------------
    Attributes
    ----------------------------------------
    stats : dict
        Nested dictionary storing accumulated statistics per layer. Normalized
        results are returned by `get_stats()`.

    handles : list
        A list of hook handles returned by `register_forward_hook`. These are
        stored to later remove all hooks safely.

    ----------------------------------------
    Usage Example
    ----------------------------------------
    >>> model: torch.nn.Module
    >>> watcher: dict[str, bool]
    >>> logger: logging.Logger
    >>> x: torch.Tensor
    >>> target: torch.Tensor
    >>> criterion: torch.nn.Module   # Loss function, e.g. torch.nn.MSELoss()

    >>> monitor = HookMonitor(model, watcher, logger)
    >>> monitor.attach()
    >>> output = model(x)
    >>> loss = criterion(output, target)
    >>> loss.backward()
    >>> stats = monitor.get_stats()
    >>> monitor.remove()

    Or using a context manager:

    >>> with HookMonitor(model, watcher, logger) as monitor:
    ...     output = model(x)
    ...     loss = criterion(output, target)
    ...     loss.backward()
    >>> stats = monitor.get_stats()

    ----------------------------------------
    Notes
    ----------------------------------------
    - The gradient hook is attached to the activation tensor (module output),
      not to model parameters.
    - No gradients are tracked during forward hooks thanks to @torch.no_grad().
    - The monitor does not interfere with the training process: it only reads
      activations and gradients.
    - Missing '/valid/' counters trigger an error log and skip normalization for
      that metric.

    """
    def __init__(self, model: torch.nn.Module, watcher: dict, logger: logging.Logger):
        """
        Initialize a HookMonitor instance to track activation and gradient
        statistics across all modules of a PyTorch model.

        This constructor does not attach any hooks yet; it simply stores the
        monitoring configuration. Hooks are registered only when `attach()` or
        the context manager (`with HookMonitor(...)`) is used.

        Parameters
        ----------
        model : torch.nn.Module
            The model whose internal modules will be monitored. Every submodule
            returned by `model.named_modules()` will receive a forward hook.

        watcher : dict
            Dictionary of boolean flags controlling which statistics should be
            computed. Keys must match the names in `REG_FUNCTION_MAP`.
            Example:
                {
                    "mean": True,
                    "std": False,
                    "grad_mean": True
                }

            Any metric not enabled here will not be computed during execution.

        logger : logging.Logger
            Logging instance used for reporting errors, debug messages, and
            warnings during monitoring operations.

        Attributes Initialized
        ----------------------
        model : torch.nn.Module
            Stored reference to the monitored model.

        watcher : dict
            The watcher configuration controlling metric activation.

        stats : dict
            Internal dictionary used to accumulate statistics across all layers.

        handles : list
            A list of hook handles created when calling `.attach()`. Each handle
            is later used to safely remove hooks with `.remove()`.

        Notes
        -----
        - No hooks are installed at construction time.
        - The monitor becomes active only after calling `.attach()` or entering
          a `with` block.
        """
        self.logger: logging.Logger = logger
        self.model: torch.nn.Module = model
        self.watcher: dict = watcher
        self.stats: dict = dict()
        self.handles: list = list()

    def _build_hook(self, name):

        @torch.no_grad()
        def hook(module, inputs, act):

            if torch.is_tensor(act):
                act_detached = act.detach().float()
                s = self.stats.setdefault(name, {})

                # Accumulate the enabled activation metrics:
                for function_name, compute_function in REG_FUNCTION_MAP.items():
                    if self.watcher.get(function_name, False) and not function_name.startswith('grad_'):
                        value = compute_function(act_detached, ...)
                        if value is not None:
                            s[function_name] = s.get(function_name, 0.0) + value
                            s[function_name + '/valid/'] = s.get(function_name + '/valid/', 0.0) + 1

                # Gradient hook, registered on the activation tensor itself:
                def grad_hook(grad):
                    gd = grad.detach().float()
                    # Accumulate the enabled gradient metrics:
                    for gd_function_name, gd_compute_function in REG_FUNCTION_MAP.items():
                        if self.watcher.get('grad_' + gd_function_name, False) and not gd_function_name.startswith('grad_'):
                            key = 'grad_' + gd_function_name
                            gd_value = gd_compute_function(gd, ...)
                            if gd_value is not None:
                                s[key] = s.get(key, 0.0) + gd_value
                                s[key + '/valid/'] = s.get(key + '/valid/', 0.0) + 1

                if act.requires_grad:
                    act.register_hook(grad_hook)

        return hook

    def get_stats(self) -> dict:
        """
        Return the accumulated statistics, normalized per layer.
        :return: A dictionary mapping each layer name to {metric: sum / valid count}.
        """
        stats = dict()
        for layer_name, layer_stats in self.stats.items():
            sub_stats = dict()
            for key, item in layer_stats.items():
                if '/valid/' not in key:
                    if key + '/valid/' in layer_stats:
                        sub_stats[key] = item / layer_stats[key + '/valid/']
                    else:
                        self.logger.error(f"Key {key} has no valid count, skipping normalization.")
                        sub_stats[key] = item
            stats[layer_name] = sub_stats
        return stats

    def attach(self):
        """
        Register a forward hook on every module returned by `model.named_modules()`.
        :return: This HookMonitor instance, so the call can be chained.
        """
        for name, module in self.model.named_modules():
            h = module.register_forward_hook(self._build_hook(name))
            self.handles.append(h)
        return self

    def clear(self):
        """
        Clear the accumulated statistics.
        :return: Nothing.
        """
        self.stats.clear()

    def remove(self):
        """
        Remove all the hooks from the model.
        :return: Nothing.
        """
        for h in self.handles:
            h.remove()
        self.handles.clear()

    def __enter__(self):
        self.logger.debug("[Hooks] Attaching HookMonitor...")
        return self.attach()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.logger.debug("[Hooks] Removing HookMonitor...")
        self.remove()

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
#                          END OF FILE                              #
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
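The accumulate-and-normalize pattern used by `HookMonitor` can be sanity-checked in isolation. The sketch below is a standalone illustration, not part of this module: a single hypothetical "mean" metric stands in for the entries of `REG_FUNCTION_MAP`, and both the activation path and the gradient-hook path are exercised on a tiny model.

```python
import torch

# Hypothetical stand-in for REG_FUNCTION_MAP (the real map lives in .functions):
DEMO_FUNCTION_MAP = {"mean": lambda t: t.mean().item()}

stats = {}

def build_hook(name):
    @torch.no_grad()
    def hook(module, inputs, output):
        # Same pattern as HookMonitor: detach, cast, accumulate value + counter.
        if torch.is_tensor(output):
            s = stats.setdefault(name, {})
            act = output.detach().float()
            for fn_name, fn in DEMO_FUNCTION_MAP.items():
                s[fn_name] = s.get(fn_name, 0.0) + fn(act)
                s[fn_name + "/valid/"] = s.get(fn_name + "/valid/", 0.0) + 1

            # Gradient hook on the activation tensor, fired during backward:
            def grad_hook(grad):
                s["grad_mean"] = s.get("grad_mean", 0.0) + grad.float().mean().item()
                s["grad_mean/valid/"] = s.get("grad_mean/valid/", 0.0) + 1

            if output.requires_grad:
                output.register_hook(grad_hook)
    return hook

model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU())
# Skip the root module (named '') so stats keys match submodule names:
handles = [m.register_forward_hook(build_hook(n))
           for n, m in model.named_modules() if n]

for _ in range(2):          # two forward/backward passes -> each counter ends at 2
    out = model(torch.randn(8, 4))
    out.sum().backward()

# Normalization as done by get_stats(): accumulated sum / valid count.
normalized = {
    layer: {k: v / s[k + "/valid/"] for k, v in s.items() if "/valid/" not in k}
    for layer, s in stats.items()
}
for h in handles:
    h.remove()
print(sorted(normalized))   # → ['0', '1']
```

For the final ReLU layer, the gradient of `out.sum()` with respect to its output is all ones, so its normalized `grad_mean` comes out exactly 1.0, which makes this a convenient self-check for the hook wiring.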