File size: 6,208 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager

import torch
from torch.nn import Module

from nemo.core.classes.common import FileIO, Serialization, Typing
from nemo.utils import logging

# Names exported by `from <this module> import *`; only the NeuralModule base class is public.
__all__ = ['NeuralModule']


class NeuralModule(Module, Typing, Serialization, FileIO):
    """
    Abstract class offering interface shared between all PyTorch Neural Modules.

    On top of `torch.nn.Module` it provides:
        * `num_weights` - number of trainable parameters.
        * `input_example()` - hook for supplying example inputs.
        * `freeze()` / `unfreeze()` / `as_frozen()` - helpers toggling `requires_grad` on all parameters, with
          optional restoration of the state recorded at freeze time.
    """

    @property
    def num_weights(self):
        """
        Utility property that returns the number of trainable parameters of NeuralModule.

        Only parameters with `requires_grad=True` are counted, so the value is 0 while the module is frozen.
        """
        return self._num_weights()

    @torch.jit.ignore
    def _num_weights(self):
        # Count the elements of every parameter that currently requires gradients.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def input_example(self, max_batch=None, max_dim=None):
        """
        Override this method if random inputs won't work.

        Args:
            max_batch: Unused by this base implementation.
            max_dim: Unused by this base implementation.

        Returns:
            A tuple sample of valid input data. This base implementation returns None.
        """
        return None

    def freeze(self) -> None:
        """
        Freeze all params for inference.

        This method sets `requires_grad` to False for all parameters of the module and puts the module in eval mode.
        It also stores the `requires_grad` state each parameter had when this method was called in
        `self._frozen_grad_map`, so that `unfreeze(partial=True)` can restore it.

        If a map from an earlier `freeze()` still exists, it is updated in place, i.e. the state seen by the most
        recent `freeze()` call wins.
        """
        grad_map = {}

        for pname, param in self.named_parameters():
            # Store the grad state the parameter had before this call
            grad_map[pname] = param.requires_grad
            # Freeze the parameter
            param.requires_grad = False

        # Store the frozen grad map, merging with the map of a previous `freeze()` if there is one
        if not hasattr(self, '_frozen_grad_map'):
            self._frozen_grad_map = grad_map
        else:
            self._frozen_grad_map.update(grad_map)

        self.eval()

    def unfreeze(self, partial: bool = False) -> None:
        """
        Unfreeze parameters for training and put the module in train mode.

        Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with
        `freeze()`). With `partial=False` every parameter becomes trainable. With `partial=True` every parameter gets
        back the `requires_grad` state that was recorded by `freeze()`, so parameters that were already frozen before
        `freeze()` stay frozen.

        In both cases the state recorded by `freeze()` is discarded afterwards.

        Example:
            Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always.

            ```python
            model.encoder.freeze()  # Freezes all parameters in the encoder explicitly
            ```

            During inference, all parameters of the model should be frozen - we do this by calling the model's freeze
            method. This step records that the encoder module parameters were already frozen, and so if partial
            unfreeze is called, we should keep the encoder parameters frozen.

            ```python
            model.freeze()  # Freezes all parameters in the model; encoder remains frozen
            ```

            Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by
            calling `unfreeze(partial=True)`.

            ```python
            model.unfreeze(partial=True)  # Unfreezes only the decoder; encoder remains frozen
            ```

        Args:
            partial: If True, restore the `requires_grad` state recorded by `freeze()` instead of unfreezing
                everything. If a parameter was already frozen when calling `freeze()`, it will remain frozen after
                calling `unfreeze(partial=True)`. Parameters that are missing from the recorded state are unfrozen
                and a warning is logged.

        Raises:
            ValueError: If `partial=True` but the module holds no state recorded by `freeze()`.
        """
        if partial and not hasattr(self, '_frozen_grad_map'):
            raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")

        for pname, param in self.named_parameters():
            if not partial:
                # Unfreeze all parameters
                param.requires_grad = True
            elif pname in self._frozen_grad_map:
                # Restore the state the parameter had before `freeze()`
                param.requires_grad = self._frozen_grad_map[pname]
            else:
                # Log a warning if the parameter was not found in the frozen grad map
                logging.warning(
                    f"Parameter {pname} not found in list of previously frozen parameters. "
                    f"Unfreezing this parameter."
                )
                param.requires_grad = True

        # Clean up the frozen grad map
        if hasattr(self, '_frozen_grad_map'):
            delattr(self, '_frozen_grad_map')

        self.train()

    @contextmanager
    def as_frozen(self):
        """
        Context manager which temporarily freezes a module, yields control and finally unfreezes the module partially
        to return to original state.

        On exit, every parameter gets back the `requires_grad` state it had when the context was entered, and the
        module is put back into the train/eval mode it was in (as given by `self.training`).

        The context can be nested and can be used on a module that was already frozen with `freeze()`: any state
        recorded by that earlier `freeze()` is preserved, so a later `unfreeze(partial=True)` still works.

        Example:
            with model.as_frozen():
                # Do something with the model
                pass

            # Model's parameters are now back to original state of requires_grad
        """
        training_mode = self.training

        # `freeze()` merges into `_frozen_grad_map` and `unfreeze()` deletes it. Keep a copy of a record left by an
        # earlier `freeze()` (or by an enclosing `as_frozen()`), otherwise it would be lost when this context exits.
        outer_grad_map = getattr(self, '_frozen_grad_map', None)
        if outer_grad_map is not None:
            outer_grad_map = dict(outer_grad_map)

        self.freeze()
        try:
            yield
        finally:
            # Restores the requires_grad state seen by the `freeze()` call above
            self.unfreeze(partial=True)

            # Reinstate the record that existed before entering the context
            if outer_grad_map is not None:
                self._frozen_grad_map = outer_grad_map

            # `unfreeze()` always switches to train mode; go back to the mode found on entry
            if training_mode:
                self.train()
            else:
                self.eval()