| | |
| | import math |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from mmengine.model import ExponentialMovingAverage |
| | from torch import Tensor |
| |
|
| | from mmdet.registry import MODELS |
| |
|
| |
|
@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
    """Exponential moving average (EMA) whose momentum decays exponentially
    with the number of updates. This is the strategy used in YOLOX.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The base momentum for updating the EMA parameters.
            With a momentum value ``m`` an update is
            `averaged_param = (1 - m) * averaged_param + m * source_param`.
            Defaults to 0.0002.
        gamma (int): Time constant of the momentum decay; must be greater
            than 0. The momentum actually applied at a given step is
            `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`, so it
            is large early in training and anneals towards ``momentum``,
            which updates the EMA model smoothly. Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model is
            stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): If True, running averages are computed for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        # Everything except ``gamma`` is handled by the parent EMA class.
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Blend ``source_param`` into ``averaged_param`` in place using the
        exponentially decaying momentum.

        Args:
            averaged_param (Tensor): The averaged parameters; modified in
                place.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        base_momentum = self.momentum
        # ``1 + steps`` is converted to a Python float before the division,
        # so the exponent handed to ``math.exp`` is a plain number.
        decay = math.exp(-float(1 + steps) / self.gamma)
        # Anneals from a value close to 1 down to ``base_momentum`` as
        # ``steps`` grows.
        weight = (1 - base_momentum) * decay + base_momentum
        # averaged_param += weight * (source_param - averaged_param)
        averaged_param.lerp_(source_param, weight)
| |
|