|
|
from copy import deepcopy
|
|
|
import torch
|
|
|
import json
|
|
|
from collections import OrderedDict
|
|
|
import math
|
|
|
|
|
|
|
|
|
class ModelEma(torch.nn.Module):
    """Keep an exponential moving average (EMA) copy of a model's state.

    A deep copy of ``model`` is stored in ``self.module`` and put in eval
    mode. Each call to :meth:`update` blends the tracked model's
    ``state_dict`` values into that copy::

        ema = decay * ema + (1 - decay) * model

    Args:
        model: Module to track. It is deep-copied, never referenced.
        decay: Base EMA decay factor.
        tau: When non-zero, the effective decay is warmed up as
            ``decay * (1 - exp(-updates / tau))``. ``0`` disables warm-up.
        device: Optional device for the EMA copy. When given, source
            tensors are moved to it before being blended.
    """

    def __init__(self, model, decay=0.9997, tau=0, device=None):
        super(ModelEma, self).__init__()
        # Work on a private copy so EMA updates never touch `model` itself.
        self.module = deepcopy(model)
        self.module.eval()

        self.decay = decay
        self.tau = tau
        self.updates = 1  # index of the next update; drives the tau warm-up
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _get_decay(self):
        """Return the decay factor to apply for the current update."""
        if self.tau == 0:
            return self.decay
        return self.decay * (1 - math.exp(-self.updates / self.tau))

    def _update(self, model, update_fn):
        """Write ``update_fn(ema_value, model_value)`` into each EMA entry.

        Entries are paired by position in the two ``state_dict``s; ``zip``
        stops at the shorter one. Only floating-point (or complex) entries
        go through ``update_fn``. Integer and boolean entries, such as
        BatchNorm's ``num_batches_tracked`` counter, are copied verbatim:
        blending them in floating point and casting back would truncate
        the result and leave them lagging behind the source model.
        """
        with torch.no_grad():
            ema_values = self.module.state_dict().values()
            model_values = model.state_dict().values()
            for ema_v, model_v in zip(ema_values, model_values):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    ema_v.copy_(model_v)

    def update(self, model):
        """Blend ``model``'s current state into the EMA copy.

        Uses the decay from :meth:`_get_decay` and then increments
        ``self.updates``.
        """
        decay = self._get_decay()
        self._update(model, update_fn=lambda e, m: decay * e + (1. - decay) * m)
        self.updates += 1

    def set(self, model):
        """Overwrite the EMA copy's state with ``model``'s current state."""
        self._update(model, update_fn=lambda e, m: m)
|
|
|
|
|
|
|
|
|
class BestMetricSingle():
    """Track the best value seen for one metric and the epoch it occurred.

    Args:
        init_res: Starting value that a result must strictly beat to count
            as a new best.
        better: ``'large'`` if higher values are better, ``'small'`` if
            lower values are better.

    Raises:
        AssertionError: If ``better`` is neither ``'large'`` nor ``'small'``.
    """

    def __init__(self, init_res=0.0, better='large') -> None:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for existing callers.
        if better not in ('large', 'small'):
            raise AssertionError(
                f"better must be 'large' or 'small', got {better!r}")

        self.init_res = init_res
        self.best_res = init_res
        self.best_ep = -1  # -1 until some result has beaten init_res
        self.better = better

    def isbetter(self, new_res, old_res):
        """Return whether ``new_res`` strictly improves on ``old_res``.

        Raises:
            ValueError: If ``self.better`` no longer holds a supported mode
                (for example after being reassigned). Failing here avoids
                silently treating every result as "not better".
        """
        if self.better == 'large':
            return new_res > old_res
        if self.better == 'small':
            return new_res < old_res
        raise ValueError(
            f"better must be 'large' or 'small', got {self.better!r}")

    def update(self, new_res, ep):
        """Record ``new_res`` if it beats the current best.

        Returns:
            True if ``new_res`` became the new best (ties do not count),
            False otherwise.
        """
        if self.isbetter(new_res, self.best_res):
            self.best_res = new_res
            self.best_ep = ep
            return True
        return False

    def __str__(self) -> str:
        return f"best_res: {self.best_res}\t best_ep: {self.best_ep}"

    def __repr__(self) -> str:
        return self.__str__()

    def summary(self) -> dict:
        """Return the best result and its epoch as a plain dict."""
        return {
            'best_res': self.best_res,
            'best_ep': self.best_ep,
        }
|
|
|
|
|
|
|
|
|
class BestMetricHolder():
    """Track the best metric overall and, optionally, per model variant.

    Without ``use_ema`` only one overall tracker is kept. With ``use_ema``
    two more trackers are kept, one for EMA results and one for regular
    results; every result still feeds the overall tracker.

    Args:
        init_res: Starting value passed to every underlying tracker.
        better: Comparison mode passed to every underlying tracker.
        use_ema: Whether to keep the separate EMA / regular trackers.
    """

    def __init__(self, init_res=0.0, better='large', use_ema=False) -> None:
        self.best_all = BestMetricSingle(init_res, better)
        self.use_ema = use_ema
        if use_ema:
            # These two attributes only exist when EMA tracking is enabled.
            self.best_ema = BestMetricSingle(init_res, better)
            self.best_regular = BestMetricSingle(init_res, better)

    def update(self, new_res, epoch, is_ema=False):
        """
        Record a result and report whether it is the best overall.

        When EMA tracking is enabled, the result is first recorded in the
        EMA or regular tracker according to ``is_ema``. The return value
        always comes from the overall tracker.
        """
        if self.use_ema:
            variant = self.best_ema if is_ema else self.best_regular
            variant.update(new_res, epoch)
        return self.best_all.update(new_res, epoch)

    def summary(self):
        """Return the tracked bests as a flat dict.

        Without EMA tracking this is the overall tracker's summary as-is.
        Otherwise each tracker's summary keys are prefixed with ``all_``,
        ``regular_`` or ``ema_``, in that order.
        """
        if not self.use_ema:
            return self.best_all.summary()

        labelled = (
            ('all', self.best_all),
            ('regular', self.best_regular),
            ('ema', self.best_ema),
        )
        merged = {}
        for label, tracker in labelled:
            for key, value in tracker.summary().items():
                merged[f'{label}_{key}'] = value
        return merged

    def __repr__(self) -> str:
        # Pretty-printed JSON of the summary.
        summary = self.summary()
        return json.dumps(summary, indent=2)

    def __str__(self) -> str:
        return self.__repr__()
|
|
|
|
|
|
|
|
|
def clean_state_dict(state_dict):
    """Return a copy of ``state_dict`` with one leading ``'module.'`` removed.

    Keys that start with ``'module.'`` (such as those of a wrapper that
    holds the real model in a ``module`` attribute) lose that prefix once;
    all other keys are kept unchanged. Values are passed through untouched
    and the original iteration order is preserved.

    Returns:
        A new ``OrderedDict``; the input mapping is not modified.
    """
    prefix = 'module.'
    cut = len(prefix)
    return OrderedDict(
        (name[cut:] if name[:cut] == prefix else name, value)
        for name, value in state_dict.items()
    )
|
|
|
|