import os
from collections import Counter

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

from model import Model
from utils import interact, Map


class Optimizer(object):
    def __init__(self, args, model):
        self.args = args

        self.save_dir = os.path.join(self.args.save_dir, 'optim')
        os.makedirs(self.save_dir, exist_ok=True)

        if isinstance(model, Model):
            model = model.model

        kwargs_optimizer = {
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }

        if args.optimizer == 'SGD':
            optimizer_class = optim.SGD
            kwargs_optimizer['momentum'] = args.momentum
        elif args.optimizer == 'ADAM':
            optimizer_class = optim.Adam
            kwargs_optimizer['betas'] = args.betas
            kwargs_optimizer['eps'] = args.epsilon
        elif args.optimizer == 'RMSPROP':
            optimizer_class = optim.RMSprop
            kwargs_optimizer['eps'] = args.epsilon

        if args.scheduler == 'step':
            scheduler_class = lrs.MultiStepLR
            kwargs_scheduler = {
                'milestones': args.milestones,
                'gamma': args.gamma,
            }
        elif args.scheduler == 'plateau':
            scheduler_class = lrs.ReduceLROnPlateau
            kwargs_scheduler = {
                'mode': 'min',
                'factor': args.gamma,
                'patience': 10,
                'verbose': True,
                'threshold': 0,
                'threshold_mode': 'abs',
                'cooldown': 10,
            }

        self.kwargs_optimizer = kwargs_optimizer
        self.scheduler_class = scheduler_class
        self.kwargs_scheduler = kwargs_scheduler
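
        # _get_optimizer builds one optimizer per network (G, and D when present).
        # The object it returns subclasses the torch.optim class selected above and
        # also owns its LR scheduler, so optimizer and scheduler are stepped, saved,
        # and restored together.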
        def _get_optimizer(model):

            class _Optimizer(optimizer_class):
                def __init__(self, model, args, scheduler_class, kwargs_scheduler):
                    trainable = filter(lambda x: x.requires_grad, model.parameters())
                    super(_Optimizer, self).__init__(trainable, **kwargs_optimizer)

                    self.args = args

                    self._register_scheduler(scheduler_class, kwargs_scheduler)

                def _register_scheduler(self, scheduler_class, kwargs_scheduler):
                    self.scheduler = scheduler_class(self, **kwargs_scheduler)

                def schedule(self, metrics=None):
                    # ReduceLROnPlateau needs the monitored metric; other schedulers do not.
                    if isinstance(self.scheduler, lrs.ReduceLROnPlateau):
                        self.scheduler.step(metrics)
                    else:
                        self.scheduler.step()

                def get_last_epoch(self):
                    return self.scheduler.last_epoch

                def get_lr(self):
                    return self.param_groups[0]['lr']

                def get_last_lr(self):
                    return self.scheduler.get_last_lr()[0]

                def state_dict(self):
                    state_dict = super(_Optimizer, self).state_dict()
                    state_dict['scheduler'] = self.scheduler.state_dict()

                    return state_dict

                def load_state_dict(self, state_dict, epoch=None):
                    super(_Optimizer, self).load_state_dict(state_dict)
                    self.scheduler.load_state_dict(state_dict['scheduler'])
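
                    # If the configured milestones/gamma no longer match the restored
                    # MultiStepLR state, rebuild the schedule and recompute each param
                    # group's lr from its initial_lr for the resumed epoch.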
                    reschedule = False
                    if isinstance(self.scheduler, lrs.MultiStepLR):
                        if self.args.milestones != list(self.scheduler.milestones) or self.args.gamma != self.scheduler.gamma:
                            reschedule = True

                    if reschedule:
                        if epoch is None:
                            if self.scheduler.last_epoch > 1:
                                epoch = self.scheduler.last_epoch
                            else:
                                epoch = self.args.start_epoch - 1

                        self.scheduler.milestones = Counter(self.args.milestones)
                        self.scheduler.gamma = self.args.gamma

                        for i, group in enumerate(self.param_groups):
                            self.param_groups[i]['lr'] = group['initial_lr']
                            multiplier = 1
                            for milestone in self.scheduler.milestones:
                                if epoch >= milestone:
                                    multiplier *= self.scheduler.gamma

                            self.param_groups[i]['lr'] *= multiplier

            return _Optimizer(model, args, scheduler_class, kwargs_scheduler)

        self.G = _get_optimizer(model.G)
        if model.D is not None:
            self.D = _get_optimizer(model.D)
        else:
            self.D = None

        self.load(args.load_epoch)

    def zero_grad(self):
        self.G.zero_grad()

    def step(self):
        self.G.step()

    def schedule(self, metrics=None):
        self.G.schedule(metrics)
        if self.D is not None:
            self.D.schedule(metrics)

    def get_last_epoch(self):
        return self.G.get_last_epoch()

    def get_lr(self):
        return self.G.get_lr()

    def get_last_lr(self):
        return self.G.get_last_lr()

    def state_dict(self):
        state_dict = Map()
        state_dict.G = self.G.state_dict()
        if self.D is not None:
            state_dict.D = self.D.state_dict()

        return state_dict.toDict()

    def load_state_dict(self, state_dict, epoch=None):
        state_dict = Map(**state_dict)
        self.G.load_state_dict(state_dict.G, epoch)
        if self.D is not None:
            self.D.load_state_dict(state_dict.D, epoch)

    def _save_path(self, epoch=None):
        epoch = epoch if epoch is not None else self.get_last_epoch()
        save_path = os.path.join(self.save_dir, 'optim-{:d}.pt'.format(epoch))

        return save_path

    def save(self, epoch=None):
        if epoch is None:
            epoch = self.G.scheduler.last_epoch
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch):
        if epoch > 0:
            print('Loading optimizer from {}'.format(self._save_path(epoch)))
            self.load_state_dict(torch.load(self._save_path(epoch), map_location=self.args.device), epoch=epoch)
        elif epoch == 0:
            pass
        else:
            raise NotImplementedError

        return
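

if __name__ == '__main__':
    # Minimal usage sketch, not part of the training pipeline: it stands in for the
    # real entry point, which builds `args` via argparse and networks via the Model
    # wrapper. `_DummyModel`, the temp save_dir, and the args field values below are
    # illustrative assumptions only; it still requires the repo's `model`/`utils`
    # modules to be importable.
    from types import SimpleNamespace

    import torch.nn as nn

    class _DummyModel:
        def __init__(self):
            self.G = nn.Linear(4, 4)   # stands in for the generator network
            self.D = None              # no discriminator in this sketch

    args = SimpleNamespace(
        save_dir='/tmp/optim_demo', lr=1e-4, weight_decay=0,
        optimizer='ADAM', betas=(0.9, 0.999), epsilon=1e-8,
        scheduler='step', milestones=[500, 750, 900], gamma=0.5,
        load_epoch=0, start_epoch=1, device='cpu',
    )

    optimizer = Optimizer(args, _DummyModel())
    for _ in range(3):              # one dummy "epoch" per iteration
        optimizer.zero_grad()
        # ... forward/backward pass would go here ...
        optimizer.step()
        optimizer.schedule()        # advances the MultiStepLR scheduler
    optimizer.save()                # writes <save_dir>/optim/optim-<epoch>.pt
    print(optimizer.get_lr(), optimizer.get_last_epoch())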