# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2025 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import math
from copy import deepcopy

import torch
import torch.nn as nn

__all__ = ["ModelEMA", "is_parallel"]


def is_parallel(model):
    """Return True if ``model`` is a DataParallel or DistributedDataParallel wrapper."""
    return isinstance(
        model,
        (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel),
    )


class ModelEMA:
    """
    Exponential moving average (EMA) of a model's weights, adapted from
    https://github.com/rwightman/pytorch-image-models

    A deep copy of the model is kept in ``self.ema`` and every floating-point
    entry of its state_dict (parameters and buffers) is blended towards the
    live model on each call to :meth:`update`. Non-floating-point entries are
    left as they were at copy time.

    The intent is functionality comparable to
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    Some training schemes need a smoothed version of the weights to perform
    well. Where this class is instantiated matters relative to model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """
        Args:
            model (nn.Module): model to apply EMA to; a parallel wrapper is
                unwrapped through its ``.module`` attribute before copying.
            decay (float): EMA decay rate that the ramp converges towards.
            updates (int): counter of EMA updates already performed.
        """
        # Copy the bare module (not the parallel wrapper) and put the copy in
        # eval mode.
        source = model.module if is_parallel(model) else model
        self.ema = deepcopy(source).eval()
        self.updates = updates

        def ramped_decay(x):
            # Exponential ramp (to help early epochs): close to 0 for small
            # x, approaching `decay` as x grows.
            return decay * (1 - math.exp(-x / 2000))

        self.decay = ramped_decay

        # The averaged copy is never trained directly.
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend the current weights of ``model`` into the EMA copy.

        Increments ``self.updates``, evaluates the ramped decay ``d`` for the
        new count, and applies ``ema = d * ema + (1 - d) * model`` in place to
        every floating-point entry of the EMA state_dict.

        Args:
            model (nn.Module): live model, optionally wrapped in
                DataParallel / DistributedDataParallel.
        """
        with torch.no_grad():
            self.updates += 1
            weight = self.decay(self.updates)

            # Read the state of the underlying module so its keys line up with
            # the EMA copy, which was taken from the unwrapped module.
            live = model.module if is_parallel(model) else model
            live_state = live.state_dict()

            for name, averaged in self.ema.state_dict().items():
                if not averaged.dtype.is_floating_point:
                    continue  # integer/bool entries are not averaged
                averaged.mul_(weight)
                averaged.add_((1.0 - weight) * live_state[name].detach())