Spaces:
Sleeping
Sleeping
| import torch | |
class EMA:
    """Exponential moving average (EMA) of a model's ``state_dict``.

    A detached copy ("shadow") of every ``state_dict`` entry is kept.  For
    the first ``warmup_steps`` calls to :meth:`update` the shadow simply
    mirrors the model; afterwards floating-point entries are blended as
    ``shadow = decay * shadow + (1 - decay) * model`` while non-floating
    entries are copied verbatim.

    :meth:`apply_shadow` loads the averaged values into the model (saving
    the current ones first) and :meth:`restore` puts the saved values back.
    """

    def __init__(self, model, decay=0.999, warmup_steps=10000):
        """Snapshot ``model.state_dict()`` as the initial shadow.

        Args:
            model: Object exposing ``state_dict()`` / ``load_state_dict()``
                (presumably a ``torch.nn.Module``).
            decay: Weight given to the previous shadow value once warmup
                is over.
            warmup_steps: Number of initial ``update`` calls during which
                the shadow is a plain copy of the model.
        """
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.num_updates = 0
        # clone() so the shadow does not share storage with the live model.
        self.shadow = {
            name: tensor.clone().detach()
            for name, tensor in model.state_dict().items()
        }
        self.backup = {}

    def update(self, model):
        """Fold the current state of ``model`` into the shadow.

        Entries of the shadow that are absent from ``model.state_dict()``
        are left untouched.
        """
        self.num_updates += 1
        model_state = model.state_dict()

        if self.num_updates <= self.warmup_steps:
            # Warmup: track the model exactly instead of averaging.
            for name, shadow_v in self.shadow.items():
                if name in model_state:
                    shadow_v.copy_(model_state[name].detach())
            return

        decay = self.decay
        for name, shadow_v in self.shadow.items():
            if name not in model_state:
                continue
            model_v = model_state[name].detach()
            if not model_v.dtype.is_floating_point:
                # Non-floating entries cannot be averaged; mirror them.
                shadow_v.copy_(model_v)
            else:
                # Align devices first: copy_ (used above) tolerates a
                # device mismatch but add_ does not.  No-op when equal.
                model_v = model_v.to(shadow_v.device)
                shadow_v.mul_(decay).add_(model_v, alpha=1.0 - decay)

    def apply_shadow(self, model):
        """Load the EMA shadow values into ``model``.

        The model's current state is saved so :meth:`restore` can undo the
        swap.  If a backup from an earlier, un-restored call is still held
        it is kept: re-snapshotting would overwrite the saved weights with
        the shadow values that are already loaded in the model.
        """
        if not self.backup:
            self.backup = {
                name: tensor.clone().detach()
                for name, tensor in model.state_dict().items()
            }
        model.load_state_dict(self.shadow, strict=False)

    def restore(self, model):
        """Put back the state saved by :meth:`apply_shadow`.

        Does nothing when no backup is held.  The backup is cleared after
        a successful restore.
        """
        if not self.backup:
            return
        model.load_state_dict(self.backup, strict=False)
        self.backup = {}