Rohan3's picture
deploy backend
4aabce3
import torch
class EMA:
    """Exponential moving average (EMA) of a model's ``state_dict``.

    Keeps a detached "shadow" copy of every entry in ``model.state_dict()``
    (parameters and buffers). For the first ``warmup_steps`` calls to
    :meth:`update` the shadow simply mirrors the model. After that,
    floating-point entries are blended as
    ``shadow = decay * shadow + (1 - decay) * current``, and
    non-floating-point entries (e.g. integer counters) are copied verbatim.

    Evaluation pattern::

        ema.apply_shadow(model)   # swap EMA weights in (backs up current state)
        ...                       # evaluate
        ema.restore(model)        # swap the backed-up weights back

    Args:
        model: Module whose state is tracked.
        decay: Weight given to the existing shadow value after warmup.
        warmup_steps: Number of initial ``update`` calls during which the
            shadow is overwritten with the model's values instead of averaged.
    """

    def __init__(self, model: torch.nn.Module, decay: float = 0.999, warmup_steps: int = 10000) -> None:
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.num_updates = 0
        # Independent copy of every state_dict entry (detach first so no autograd history is cloned).
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
        # Model state saved by apply_shadow() and consumed by restore(); empty when nothing is swapped in.
        self.backup = {}

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        """Advance the moving average by one step using ``model``'s current state.

        Shadow keys that are missing from ``model.state_dict()`` are left unchanged.
        Keys present in the model but not in the shadow are ignored.
        """
        self.num_updates += 1
        model_state = model.state_dict()
        decay = self.decay
        in_warmup = self.num_updates <= self.warmup_steps
        for k, v in self.shadow.items():
            if k not in model_state:
                continue
            # Bring the source onto the shadow's device. copy_ tolerates a cross-device
            # source but mul_/add_ do not, so without this a device mismatch would only
            # surface as a crash once warmup ends.
            model_v = model_state[k].detach().to(v.device)
            if in_warmup or not model_v.dtype.is_floating_point:
                # Warmup mirrors the model exactly; integer buffers cannot be averaged.
                v.copy_(model_v)
            else:
                v.mul_(decay).add_(model_v, alpha=1.0 - decay)

    def apply_shadow(self, model: torch.nn.Module) -> None:
        """Load the EMA weights into ``model``, backing up its current state first.

        If a backup is already held (``apply_shadow`` called again without an
        intervening :meth:`restore`), the existing backup is kept. Otherwise it
        would be overwritten with the EMA weights themselves, and the original
        weights would be lost.
        """
        if not self.backup:
            self.backup = {k: v.detach().clone() for k, v in model.state_dict().items()}
        model.load_state_dict(self.shadow, strict=False)

    def restore(self, model: torch.nn.Module) -> None:
        """Reload the state saved by :meth:`apply_shadow` into ``model``.

        Does nothing if no backup is held. The backup is cleared after restoring.
        """
        if not self.backup:
            return
        model.load_state_dict(self.backup, strict=False)
        self.backup = {}