| """ |
| D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement |
| Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. |
| --------------------------------------------------------------------------------- |
| Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| Copyright (c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
|
|
| import torch |
| import torch.nn as nn |
|
|
| import math |
| from copy import deepcopy |
|
|
| from ..core import register |
| from ..misc import dist_utils |
|
|
__all__ = ['ModelEMA']


| @register() |
| class ModelEMA(object): |
| """ |
| Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models |
| Keep a moving average of everything in the model state_dict (parameters and buffers). |
| This is intended to allow functionality like |
| https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage |
| A smoothed version of the weights is necessary for some training schemes to perform well. |
| This class is sensitive where it is initialized in the sequence of model init, |
| GPU assignment and distributed training wrappers. |
| """ |
    def __init__(self, model: nn.Module, decay: float=0.9999, warmups: int=1000, start: int=0):
        super().__init__()

        # Keep a de-parallelized deep copy of the model in eval mode; its weights hold the EMA
        self.module = deepcopy(dist_utils.de_parallel(model)).eval()

        self.decay = decay
        self.warmups = warmups
        self.before_start = 0
        self.start = start
        self.updates = 0
        if warmups == 0:
            self.decay_fn = lambda x: decay
        else:
            # Ramp the decay up from 0 toward `decay` over roughly `warmups` updates,
            # so the EMA tracks the raw weights closely early in training
            self.decay_fn = lambda x: decay * (1 - math.exp(-x / warmups))

        # The EMA weights are never trained directly
        for p in self.module.parameters():
            p.requires_grad_(False)
|
|
|
|
    def update(self, model: nn.Module):
        # Skip EMA updates until `start` steps have elapsed
        if self.before_start < self.start:
            self.before_start += 1
            return

        with torch.no_grad():
            self.updates += 1
            d = self.decay_fn(self.updates)
            msd = dist_utils.de_parallel(model).state_dict()
            for k, v in self.module.state_dict().items():
                if v.dtype.is_floating_point:
                    # In-place EMA update: v = d * v + (1 - d) * model_param
                    v *= d
                    v += (1 - d) * msd[k].detach()
|
|
| def to(self, *args, **kwargs): |
| self.module = self.module.to(*args, **kwargs) |
| return self |
|
|
| def state_dict(self, ): |
| return dict(module=self.module.state_dict(), updates=self.updates) |
|
|
| def load_state_dict(self, state, strict=True): |
| self.module.load_state_dict(state['module'], strict=strict) |
| if 'updates' in state: |
| self.updates = state['updates'] |
|
|
    def forward(self, *args, **kwargs):
        raise RuntimeError('ModelEMA is not callable; run inference with `self.module` instead.')
|
|
| def extra_repr(self) -> str: |
| return f'decay={self.decay}, warmups={self.warmups}' |
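
# A minimal usage sketch (assumption: `model`, `dataloader`, `criterion`, and `optimizer`
# are defined by the surrounding training code; the values below are illustrative):
#
#     ema = ModelEMA(model, decay=0.9999, warmups=1000)
#     for images, targets in dataloader:
#         loss = criterion(model(images), targets)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         ema.update(model)                                    # refresh EMA weights after each optimizer step
#     torch.save({'ema': ema.state_dict()}, 'checkpoint.pth')  # EMA weights can be checkpointed and resumed
#     outputs = ema.module(images)                             # evaluate with the smoothed weights
#
# Note on the warmup schedule: decay_fn reaches decay * (1 - 1/e), roughly 0.632 * decay,
# after `warmups` updates, and approaches `decay` asymptotically afterwards.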
|
|
|
|
|
|
| class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): |
| """Maintains moving averages of model parameters using an exponential decay. |
| ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` |
| `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_ |
| is used to compute the EMA. |
| """ |
| def __init__(self, model, decay, device="cpu", use_buffers=True): |
|
|
| self.decay_fn = lambda x: decay * (1 - math.exp(-x / 2000)) |
|
|
| def ema_avg(avg_model_param, model_param, num_averaged): |
| decay = self.decay_fn(num_averaged) |
| return decay * avg_model_param + (1 - decay) * model_param |
|
|
| super().__init__(model, device, ema_avg, use_buffers=use_buffers) |
|
|
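# A minimal usage sketch for the AveragedModel-based variant (assumption: `model` and a
# training loop exist elsewhere; `update_parameters` is inherited from
# torch.optim.swa_utils.AveragedModel and applies `ema_avg` to every parameter):
#
#     ema = ExponentialMovingAverage(model, decay=0.999, device="cpu")
#     for step in range(num_steps):
#         ...  # one optimization step on `model`
#         ema.update_parameters(model)   # recompute the EMA of the weights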