# mypy: allow-untyped-defs
import warnings
import weakref
from typing import Callable, Optional

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten


__all__ = ["ModTracker"]


class ModTracker:
    """

    ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution

    so that other system can query which Module is currently being executed (or its backward is being

    executed).



    You can access the ``parents`` attribute on this context manager to get the set of all the

    Modules currently being executed via their fqn (fully qualified name, also used as the key within

    the state_dict).

    You can access the ``is_bw`` attribute to know if you are currently running in backward or not.



    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag

    will remain ``True`` after the forward until another Module is executed. If you need it to be

    more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance

    is possible but not done yet, please submit an issue requesting this if you need it.



    Example usage



    .. code-block:: python



        mod = torch.nn.Linear(2, 2)



        with ModTracker() as tracker:

            # Access anything during the forward pass

            def my_linear(m1, m2, bias):

                print(f"Current modules: {tracker.parents}")

                return torch.mm(m1, m2.t()) + bias



            torch.nn.functional.linear = my_linear



            mod(torch.rand(2, 2))



    """

    parents: set[str]
    """

    A Set containing the fqn for each module currently running their forward

    """

    def __init__(self):
        self.parents = {"Global"}
        self._active_module_cnt = {}
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False
        self._post_bw_callbacks_to_enqueue: list[Callable] = []
        self._user_pre_fw_hook = None
        self._user_post_fw_hook = None
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        # Enqueue any pending pop callbacks for modules whose forward inputs did
        # not require gradients, so their fqn is still removed when backward ends.
        for post_bw_callback in reversed(self._post_bw_callbacks_to_enqueue):
            torch.autograd.Variable._execution_engine.queue_callback(post_bw_callback)
        self._post_bw_callbacks_to_enqueue.clear()

        def callback():
            # Reset the tracked hierarchy once the backward pass has finished
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True

    @property
    def is_bw(self):
        """

        A boolean marking if this is currently running during the backward pass or not

        """
        return torch._C._current_graph_task_id() != -1

    def get_known_fqn(self, mod):
        """

        Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``.

        """
        return self._known_modules.get(mod, None)

    def register_user_hooks(
        self,
        pre_fw_hook: Optional[Callable] = None,
        post_fw_hook: Optional[Callable] = None,
        pre_bw_hook: Optional[Callable] = None,
        post_bw_hook: Optional[Callable] = None,
    ):
        """

        Registers user-specified hooks to be called before/after the forward/backward pass for each

        module tracked by the ``ModTracker``. One or more can be ``None``.

        Args:

            pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the

                module. It should have the following signature:

                pre_fw_hook (module, input) -> None

            post_fw_hook (Callable, optional): A hook to be called after the forward pass for the

                module. It should have the following signature:

                post_fw_hook (module, input, output) -> None

            pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of

                the module that require gradients. It should have the following signature:

                pre_bw_hook (module, grad_output) -> None

            post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of

                the module that require gradients. It should have the following signature:

                post_bw_hook (module, grad_input) -> None

        Raises:

            AssertionError: If a new hook is provided when one is already registered.

        Note:

            If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook will

            will receive None as the module argument.

            The module fqn will be present in the ``parents`` attribute when each of the hooks is called.

            Hooks are intended to be used as markers only not to modify the inputs/outputs.

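        Example (a minimal sketch; the hook bodies here are illustrative only)

        .. code-block:: python

            tracker = ModTracker()

            def pre_fw(module, inp):
                print(f"entering: {tracker.get_known_fqn(module)}")

            def post_bw(module, grad_input):
                # ``module`` may be ``None`` if it was not kept alive until backward
                print(f"leaving backward, is_bw={tracker.is_bw}")

            tracker.register_user_hooks(pre_fw_hook=pre_fw, post_bw_hook=post_bw)
            mod = torch.nn.Linear(4, 4)
            with tracker:
                mod(torch.rand(2, 4, requires_grad=True)).sum().backward()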
        """

        def set_hook(hook, user_hook, hook_name):
            if hook is not None and user_hook is not None:
                raise AssertionError(
                    f"Only one {hook_name} can be registered at a time."
                    " Clear the existing hook by calling ``clear_user_hooks`` before registering a new one."
                )
            return hook

        self._user_pre_fw_hook = set_hook(
            pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook"
        )
        self._user_post_fw_hook = set_hook(
            post_fw_hook, self._user_post_fw_hook, "post_fw_hook"
        )
        self._user_pre_bw_hook = set_hook(
            pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook"
        )
        self._user_post_bw_hook = set_hook(
            post_bw_hook, self._user_post_bw_hook, "post_bw_hook"
        )

    def clear_user_hooks(self):
        """

        Clears the user specified hooks registered with ``register_user_hooks``

        """
        self._user_pre_fw_hook = None
        self._user_post_fw_hook = None
        self._user_pre_bw_hook = None
        self._user_post_bw_hook = None

    def _get_mod_name(self, mod):
        # Lazily assign fqns: a module first seen at the top level is named after
        # its class, and submodules get "<parent_fqn>.<attribute_name>" recursively.
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name

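    # The two factories below build the "push" and "pop" callbacks that maintain
    # ``parents``: the append fn adds a module's fqn (reference-counted via
    # ``_active_module_cnt``, since the same module can be re-entered), and the
    # pop fn removes it once every active invocation has finished. In backward,
    # the same callbacks are driven by multi-grad hooks instead of module hooks.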
    def _get_append_fn(self, w_mod, name, is_bw):
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents and not self.is_bw:

                def custom_formatwarning(msg, category, filename, lineno, line=None):
                    return f"{filename}:{lineno}: {category.__name__}: {msg} \n"

                warnings.formatwarning = custom_formatwarning
                warnings.warn(
                    "The module hierarchy tracking may be messed up."
                    " Please file a bug to PyTorch if this is the case."
                )
            if name not in self.parents:
                self._active_module_cnt[name] = 1
                self.parents.add(name)
            else:
                self._active_module_cnt[name] += 1

            if self._user_pre_bw_hook is not None and is_bw:
                self._user_pre_bw_hook(w_mod(), args)

        return fn

    def _get_pop_fn(self, w_mod, name, is_bw):
        def fn(*args):
            if self._user_post_bw_hook is not None and is_bw:
                self._user_post_bw_hook(w_mod(), args)
            if name in self.parents:
                self._active_module_cnt[name] -= 1
                if self._active_module_cnt[name] == 0:
                    self.parents.remove(name)
            elif not self.is_bw:
                # Because some inputs/outputs may not require gradients, proper
                # nesting cannot be enforced in backward, so only raise in forward
                raise RuntimeError(
                    "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
                )

        return fn

    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        self._get_append_fn(w_mod, name, False)()
        if self._user_pre_fw_hook is not None:
            self._user_pre_fw_hook(mod, input)
        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw:
            if tensors:
                # Pop the fqn in backward once the gradients for all the inputs
                # have been computed, i.e. when this module's backward is done
                register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True))
            else:
                # No input requires grad, so no multi-grad hook will fire; fall
                # back to a callback that runs at the end of the backward pass
                self._post_bw_callbacks_to_enqueue.append(
                    self._get_pop_fn(w_mod, name, True)
                )

    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        w_mod = weakref.ref(mod)
        if self._user_post_fw_hook is not None:
            self._user_post_fw_hook(mod, input, output)
        self._get_pop_fn(w_mod, name, False)()
        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if not self.is_bw and tensors:
            # Push the fqn back as soon as backward reaches any of this module's
            # outputs ("any" mode), marking the module as active in backward
            register_multi_grad_hook(
                tensors, self._get_append_fn(w_mod, name, True), mode="any"
            )

    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(
            self._fw_post_hook, always_call=True
        )
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()
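

# Minimal usage sketch (illustrative only, not part of the public API): tracks
# which modules are active during the forward and backward of a tiny model.
if __name__ == "__main__":
    seq = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    with ModTracker() as tracker:
        out = seq(torch.rand(2, 4, requires_grad=True))
        # Prints {"Global"} here, since no module forward is currently running
        print(tracker.parents)
        out.sum().backward()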