"""
This module provides the EMA class, which stores a copy of the exponentially
decayed model params.

Typical usage of the EMA class involves initializing an object from an existing
model (random or from a seed model) and setting config values such as ema_decay
and ema_start_update, which determine how the EMA model is updated. After every
update of the model, i.e. at the end of the train_step, the EMA should be
updated by passing the new model to the EMA.step function. The EMA model state
dict can be stored in the extra state under the key "ema", dumped into a
checkpoint, and loaded back. The EMA object can be passed to tasks by setting
the task.uses_ema property.
EMA is a smoothed/ensemble model which might have better performance
when used for inference or further fine-tuning. The EMA class has a
reverse function to load the EMA params into a model and use it
like a regular model.

This implementation is used for trainer-level EMA tracking. For EMA tracking
inside the model, please use fairseq/modules/ema_module.py instead.
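
A minimal usage sketch (illustrative only: the training-loop wiring and the
keyword-argument construction of EMAConfig are assumptions; only the EMA calls
come from this module):

    config = EMAConfig(ema_decay=0.999, ema_update_freq=1, ema_start_update=0)
    ema = EMA(model, config)

    for updates, batch in enumerate(data, start=1):
        train_step(model, batch)          # regular optimizer/model update
        ema.step(model, updates=updates)  # fold the new weights into the EMA copy

    ema_model = ema.reverse(copy.deepcopy(model))  # inference with EMA weights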
| """ |

import copy
import logging

import torch

from fairseq import checkpoint_utils


class EMA(object):
    """Exponential Moving Average of Fairseq Models
    EMA keeps a copy of the exponentially decayed model params.
    The set of params should include both gradient-descent and
    non-gradient-descent params, such as batch mean/var and buffers.
    This is a modified implementation of
    the open source code in https://github.com/zhawe01/fairseq-gec.git,
    and internal source code in
    fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py.

    Similar to TF EMA.
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    EMA provides an averaged and smoothed set of model weights, and has been
    shown to improve vision models. The EMA class provides the functions needed
    to initialize, update, and reload the EMA weights.

    The EMA object is initialized from an arbitrary model. By default, it is
    stored on the same device (unless a device is specified at initialization)
    and with the same precision as the model (unless ema_fp32 is True).
    ema_fp32 is recommended: the EMA parameters are then kept in fp32 for the
    EMA update step and stored at the default precision otherwise.

    EMA is usually enabled using EMAConfig with store_ema=True. Some important
    parameters to configure EMA are
    1) ema_decay - The decay of EMA
    2) ema_update_freq - EMA is updated every this many model updates.
    3) ema_start_update - Start EMA update after this many model updates [default 0]

    Key methods:
    1) step - One update of EMA using the new model
    2) restore - Update EMA from a state dict
    3) reverse - Load EMA into a model
    4) get_decay, _set_decay - Used to get or set the decay. Note that _set_decay
       is called from step.
    5) build_fp32_params - Used to initialize or update the fp32 copy of EMA params.
       Note this is enabled only when ema_fp32=True
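
    For each non-skipped key, one EMA update applies

        ema_param = ema_decay * ema_param + (1 - ema_decay) * model_param

    For example (illustrative numbers), with ema_decay=0.9999, ema_update_freq=1
    and ema_start_update=1000, the model is copied verbatim into the EMA for the
    first 1000 updates (the decay is set to 0 there) and decayed with 0.9999
    afterwards.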
| """ |

    def __init__(self, model, config, device=None, skip_keys=None):
        """
        @param model model to initialize the EMA with
        @param config EMAConfig object with configuration like
            ema_decay, ema_update_freq, ema_fp32
        @param device If provided, copy EMA to this device (e.g. gpu).
            Otherwise EMA is in the same device as the model.
        """

        self.decay = config.ema_decay
        self.model = copy.deepcopy(model)
        self.model.requires_grad_(False)
        self.config = config
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        if self.config.ema_seed_model is not None:
            # Initialize the EMA weights from a seed checkpoint instead of `model`.
            state = checkpoint_utils.load_ema_from_checkpoint(
                self.config.ema_seed_model
            )
            self.model.load_state_dict(state["model"], strict=True)

        if device is not None:
            logging.info(f"Copying EMA model to device {device}")
            self.model = self.model.to(device=device)

        if self.config.ema_fp32:
            self.build_fp32_params()

        self.update_freq_counter = 0

    def get_model(self):
        return self.model

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If a state dict is passed, the EMA params are copied from
        the provided state dict. Otherwise, they are copied from the
        current EMA model parameters.
        """
        if not self.config.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])

    def restore(self, state_dict, build_fp32_params=False):
        """Load a state dict into the EMA model."""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def _set_decay(self, decay):
        self.decay = decay

    def get_decay(self):
        return self.decay

    def _step_internal(self, new_model, updates=None):
        """One update of the EMA model based on new model weights"""
        decay = self.decay

        ema_state_dict = {}
        ema_params = (
            self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        )
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param"
                    + " {} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay the "version" param; it is metadata, not a weight.
                continue

            if key in self.skip_keys:
                # Skipped keys are copied over verbatim instead of being decayed.
                ema_param = param.to(dtype=ema_param.dtype).clone()
            else:
                # EMA update: ema_param = decay * ema_param + (1 - decay) * param
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param
        self.restore(ema_state_dict, build_fp32_params=False)

    def step(self, new_model, updates=None):
        """
        One update of EMA which is done every self.config.ema_update_freq
        updates of the model.

        @param updates The current number of model updates done.
            Decay is set to 0 if model updates < ema_start_update, which means
            the model is simply copied over to the EMA.
            When model updates >= ema_start_update, EMA is updated with
            a decay of self.config.ema_decay.
        """
        if updates is not None:
            self._set_decay(
                0 if updates < self.config.ema_start_update else self.config.ema_decay
            )
        if self.config.ema_update_freq > 1:
            self.update_freq_counter += 1
            if self.update_freq_counter >= self.config.ema_update_freq:
                self._step_internal(new_model, updates)
                self.update_freq_counter = 0
        else:
            self._step_internal(new_model, updates)

    def reverse(self, model):
        """
        Load the model parameters from the EMA model.
        Useful for inference or fine-tuning from the EMA model.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]

        model.load_state_dict(d, strict=False)
        return model