import torch


class EMA:
    """Exponential moving average of a model's ``state_dict`` entries.

    A "shadow" copy of every ``state_dict`` entry is kept next to the live
    model.  ``update`` folds the live values into the shadow,
    ``apply_shadow`` loads the shadow into the model (after backing up the
    live values), and ``restore`` loads that backup back.

    Attributes:
        decay: Weight given to the previous shadow value on each averaged
            update; the live value is weighted by ``1 - decay``.
        warmup_steps: Number of initial ``update`` calls during which the
            shadow simply mirrors the live values instead of averaging.
        num_updates: How many times ``update`` has been called.
        shadow: Mapping from ``state_dict`` key to the averaged tensor.
        backup: Live values saved by ``apply_shadow``; empty when no swap
            is in effect.
    """

    def __init__(self, model, decay=0.999, warmup_steps=10000):
        """Snapshot ``model.state_dict()`` as the initial shadow.

        Args:
            model: Object exposing ``state_dict()`` / ``load_state_dict()``
                (e.g. a ``torch.nn.Module``).
            decay: Averaging coefficient applied to the old shadow value.
            warmup_steps: Number of leading updates that copy rather than
                average.
        """
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.num_updates = 0
        # Clone so the shadow never shares storage with the live tensors.
        self.shadow = {
            name: tensor.clone().detach()
            for name, tensor in model.state_dict().items()
        }
        self.backup = {}

    @torch.no_grad()
    def update(self, model):
        """Fold the model's current values into the shadow.

        For the first ``warmup_steps`` calls every shadow entry is
        overwritten with the live value.  Afterwards floating-point entries
        become ``decay * shadow + (1 - decay) * live``; non-floating-point
        entries are copied verbatim.  Shadow keys missing from the model's
        ``state_dict`` are left untouched.

        Args:
            model: The model whose current state should be tracked.
        """
        self.num_updates += 1
        model_state = model.state_dict()
        in_warmup = self.num_updates <= self.warmup_steps
        decay = self.decay

        for name, shadow_v in self.shadow.items():
            live_v = model_state.get(name)
            if live_v is None:
                continue
            live_v = live_v.detach()

            if in_warmup or not live_v.dtype.is_floating_point:
                # copy_ handles device/dtype differences on its own.
                shadow_v.copy_(live_v)
                continue

            # add_ requires both operands on one device; align explicitly so
            # a model moved after construction does not break the update.
            # ``to`` returns the tensor itself when nothing has to change.
            live_v = live_v.to(device=shadow_v.device, dtype=shadow_v.dtype)
            shadow_v.mul_(decay).add_(live_v, alpha=1.0 - decay)

    def apply_shadow(self, model):
        """Load the shadow values into ``model``, backing up the live ones.

        If a backup is already held (``apply_shadow`` was called again
        before ``restore``), it is kept: re-snapshotting at that point would
        capture the shadow values and lose the real live weights.

        Args:
            model: The model to overwrite with the shadow values.
        """
        if not self.backup:
            self.backup = {
                name: tensor.clone().detach()
                for name, tensor in model.state_dict().items()
            }
        model.load_state_dict(self.shadow, strict=False)

    def restore(self, model):
        """Load the values saved by ``apply_shadow`` back into ``model``.

        Deliberately a no-op when no backup is held, so an unmatched call
        is harmless.

        Args:
            model: The model to restore.
        """
        if not self.backup:
            return
        model.load_state_dict(self.backup, strict=False)
        self.backup = {}