LumiNet / modi_vae /networks.py
xyxingx's picture
Upload 9 files
4035e2e verified
from typing import Any, Dict, Iterable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class EMAModel(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    A detached copy (the "shadow") of every parameter that has
    ``requires_grad=True`` is registered as a buffer of this module. Calling
    the module with the tracked model moves each shadow towards the current
    parameter value.

    Buffer names may not contain ``'.'``, so the dots are stripped from the
    parameter name (``a.b.weight`` is stored as ``abweight``).
    ``m_name2s_name`` maps each model parameter name to its shadow name.
    """

    def __init__(
        self,
        model: nn.Module,
        decay: float = 0.9999,
        use_num_updates: bool = True,
    ) -> None:
        """Create one shadow buffer per trainable parameter of ``model``.

        Args:
            model: Module whose parameters with ``requires_grad=True`` are
                tracked.
            decay: Upper bound of the EMA decay factor; must lie in [0, 1].
            use_num_updates: If true, the ``num_updates`` counter starts at 0
                and ``forward`` applies the warm-up schedule
                ``min(decay, (1 + n) / (10 + n))``. If false, the counter is
                set to -1 and ``decay`` is always used unchanged.

        Raises:
            ValueError: If ``decay`` is outside [0, 1], or if stripping the
                dots makes two names map to the same buffer.
        """
        super().__init__()
        if not 0.0 <= decay <= 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name: Dict[str, str] = {}
        self.decay = decay

        # A negative counter is the marker for "warm-up disabled".
        initial_count = 0 if use_num_updates else -1
        self.register_buffer(
            'num_updates',
            torch.tensor([initial_count], dtype=torch.int64),
        )

        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            # The '.' character is not allowed in buffer names, so drop it.
            s_name = name.replace('.', '')
            # Dropping dots is not injective ('1.0.weight' and '10.weight'
            # both become '10weight'). register_buffer would silently
            # overwrite the earlier buffer, so fail loudly instead.
            if s_name in self._buffers:
                raise ValueError(
                    f"Parameter '{name}' maps to buffer name '{s_name}', "
                    'which is already in use'
                )
            self.m_name2s_name[name] = s_name
            self.register_buffer(s_name, p.detach().clone())

        # Parameter snapshots taken by store() and consumed by restore().
        self.collected_params = []

    def forward(self, model: nn.Module) -> None:
        """Advance every shadow one EMA step towards ``model``'s parameters.

        The update is ``shadow -= (1 - decay) * (shadow - param)``. When the
        update counter is non-negative it is incremented first, and the
        effective decay is ``min(self.decay, (1 + n) / (10 + n))`` with ``n``
        the incremented count.

        Args:
            model: The model this instance was built from.
        """
        decay = self.decay
        # Read the counter once; every .item() call copies to the host.
        num_updates = int(self.num_updates.item())
        if num_updates >= 0:
            self.num_updates.add_(1)
            num_updates += 1
            decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        one_minus_decay = 1.0 - decay

        shadow_params = dict(self.named_buffers())
        with torch.no_grad():
            for key, param in model.named_parameters():
                if param.requires_grad:
                    shadow = shadow_params[self.m_name2s_name[key]]
                    # Convert the parameter to the shadow's dtype/device, not
                    # the other way round: converting the shadow would create
                    # a temporary and leave the registered buffer unchanged.
                    shadow.sub_(one_minus_decay * (shadow - param.to(shadow)))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model: nn.Module) -> None:
        """Overwrite ``model``'s trainable parameters with the shadow values.

        Args:
            model: Model whose parameters are replaced in place.
        """
        shadow_params = dict(self.named_buffers())
        for key, param in model.named_parameters():
            if param.requires_grad:
                shadow = shadow_params[self.m_name2s_name[key]]
                param.data.copy_(shadow.data)
            else:
                assert key not in self.m_name2s_name

    def store(self, model: nn.Module) -> None:
        """Save the current parameters for restoring later.

        Args:
            model: A model whose parameters will be stored.
        """
        # Detach so the snapshots do not stay attached to the autograd graph.
        self.collected_params = [
            param.detach().clone() for param in model.parameters()
        ]

    def restore(self, model: nn.Module) -> None:
        """Restore the parameters saved with the `store` method.

        Useful to validate the model with EMA parameters without affecting
        the original optimization process: call `store` before `copy_to`,
        and after validation (or model saving) call this to put the former
        parameters back. If `store` has not been called, nothing is copied.

        Args:
            model: A model whose parameters are to be restored.
        """
        for c_param, param in zip(self.collected_params, model.parameters()):
            param.data.copy_(c_param.data)