File size: 6,741 Bytes
85ba398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#!/usr/bin/env python3

"""
Exponential moving average (EMA) tracking of a given PyTorch module. The user is
responsible for calling step() and setting the appropriate decay.
"""

import copy
from dataclasses import dataclass, field
import logging

import torch

from omegaconf import II
from fairseq.dataclass import FairseqDataclass

# apex's fused L2-norm kernel is optional; norm logging is disabled without it.
try:
    from amp_C import multi_tensor_l2norm
except ImportError:
    multi_tensor_l2norm_available = False
else:
    multi_tensor_l2norm_available = True

logger = logging.getLogger(__name__)


@dataclass
class EMAModuleConfig(FairseqDataclass):
    ema_decay: float = field(
        default=0.9999, metadata={"help": "decay for exponential moving average model"}
    )
    ema_fp32: bool = field(
        default=False,
        metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"},
    )
    add_missing_params: bool = True
    log_norms: bool = False


class EMAModule:
    """Exponential Moving Average of Fairseq Models.

    Holds a model (deep-copied by default) whose parameters are blended on
    every ``step()`` as ``ema = decay * ema + (1 - decay) * new``.
    """

    def __init__(
        self,
        model,
        config: EMAModuleConfig,
        copy_model=True,
        device=None,
        skip_keys=None,
    ):
        """
        @param model model to initialize the EMA with
        @param config EMAModuleConfig object; ema_decay, ema_fp32,
        add_missing_params and log_norms are read from it
        @param copy_model If True (default), track a deep copy of ``model``
        with requires_grad disabled. Otherwise track ``model`` itself.
        @param device If provided, copy EMA to this device (e.g. gpu).
        Otherwise EMA is in the same device as the model.
        @param skip_keys Parameter names that are copied verbatim from the
        new model on each step instead of being averaged.
        """
        self.config = config

        if copy_model:
            self.model = copy.deepcopy(model)
            self.model.requires_grad_(False)
        else:
            self.model = model

        self.decay = config.ema_decay
        self.skip_keys = skip_keys or set()
        self.add_missing_params = config.add_missing_params
        self.fp32_params = {}

        if device is not None:
            logger.info("Copying EMA model to device %s", device)
            self.model = self.model.to(device=device)

        if self.config.ema_fp32:
            self.build_fp32_params()

        # Norm logging requires apex's fused kernel; silently off otherwise.
        self.log_norms = config.log_norms and multi_tensor_l2norm_available
        self.logs = {}

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If state dict is passed, the EMA params is copied from
        the provided state dict. Otherwise, it is copied from the
        current EMA model parameters.

        Entries already in ``self.fp32_params`` are updated in place with
        ``copy_``; new entries are converted with ``.float()`` when they are
        floating point and stored as-is otherwise.

        Raises:
            RuntimeError: if ``config.ema_fp32`` is False.
        """
        if not self.config.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            # NOTE(review): Tensor.float() returns the same tensor when it is
            # already float32, so new entries may share storage with state_dict.
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                # NOTE(review): "__sq_mom" appears to be a dict of per-parameter
                # tensors (it is indexed by param_key below); it is replaced
                # wholesale rather than copied in place. Its producer is not
                # visible in this file — confirm.
                if param_key == "__sq_mom":
                    self.fp32_params[param_key] = state_dict[param_key]
                else:
                    self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])
                if "__sq_mom" in self.fp32_params:
                    self.fp32_params["__sq_mom"][param_key] = torch.zeros_like(
                        self.fp32_params[param_key]
                    )

    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model spec into EMA model.

        Loading is non-strict (missing/unexpected keys are ignored). When
        ``build_fp32_params`` is True the fp32 copy is refreshed from the same
        ``state_dict``.
        """
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def set_decay(self, decay, weight_decay=None):
        """Set the EMA decay used by subsequent steps.

        ``weight_decay``, if given, is stored on the instance; it is not read
        anywhere in this class.
        """
        self.decay = decay
        if weight_decay is not None:
            self.weight_decay = weight_decay

    def get_decay(self):
        """Return the current EMA decay."""
        return self.decay

    def _step_internal(self, new_model):
        """One update of the EMA model based on new model weights.

        For every parameter of ``new_model``:
          * keys containing "version" are left untouched;
          * keys in ``skip_keys`` or params with ``requires_grad=False`` are
            copied verbatim into the EMA;
          * all others are blended in place as
            ``ema = decay * ema + (1 - decay) * param``.
        Buffers of ``new_model`` are taken as-is, and the result is loaded
        into ``self.model`` via ``restore``.

        Raises:
            ValueError: if a parameter's shape differs from its EMA copy.
        """
        decay = self.decay

        ema_state_dict = {}
        # In fp32 mode the master copy lives in self.fp32_params; otherwise
        # the EMA model's own state-dict tensors are updated in place.
        ema_params = (
            self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        )

        new_p = []
        ema_p = []

        for key, param in new_model.named_parameters():
            if isinstance(param, dict):
                continue

            if not self.add_missing_params and key not in ema_params:
                continue

            try:
                ema_param = ema_params[key]
            except KeyError:
                # New to the EMA: seed from the current model. In fp32 mode
                # keep every entry of fp32_params in fp32.
                if param.ndim == 1 or self.config.ema_fp32:
                    ema_param = param.float().clone()
                else:
                    ema_param = copy.deepcopy(param)
                ema_params[key] = ema_param

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param: "
                    + "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            lr = 1 - decay

            if key in self.skip_keys or not param.requires_grad:
                ema_param.copy_(param.to(dtype=ema_param.dtype).data)
            else:
                if self.log_norms:
                    new_p.append(param)
                    ema_p.append(ema_param)

                ema_param.mul_(1 - lr)
                ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=lr)

            ema_state_dict[key] = ema_param

        for key, buf in new_model.named_buffers():
            ema_state_dict[key] = buf

        # Skip when no parameter was averaged: there is nothing to measure,
        # and the fused kernel is presumably not meant for empty tensor lists
        # (NOTE(review): confirm against apex).
        if self.log_norms and new_p:
            if "model_norm" in self.logs:
                self.prev_model_norm = self.logs["model_norm"]

            chunk_size = 2048 * 32
            has_inf = torch.zeros((1, 1), dtype=torch.int, device=new_p[0].device)

            new_norm = multi_tensor_l2norm(chunk_size, has_inf, [new_p], False)
            old_norm = multi_tensor_l2norm(chunk_size, has_inf, [ema_p], False)

            self.logs["model_norm"] = new_norm[0]
            self.logs["ema_norm"] = old_norm[0]

        self.restore(ema_state_dict, build_fp32_params=False)

    @torch.no_grad()
    def step(self, new_model):
        """Run one EMA update from ``new_model`` with autograd disabled."""
        self._step_internal(new_model)

    def reverse(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.

        The EMA state dict (minus an ``_ema`` entry, if present) is loaded
        non-strictly into ``model``, which is returned.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]

        model.load_state_dict(d, strict=False)
        return model