import os
from collections import Counter

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

from model import Model
from utils import interact, Map


class Optimizer(object):
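    """Wrapper around the per-network optimizers.

    Holds one optimizer for the generator G and, when a discriminator is
    present, one for D; each carries its own LR scheduler plus save/load
    helpers keyed by epoch.
    """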

    def __init__(self, args, model):
        self.args = args

        # Optimizer checkpoints are saved under <args.save_dir>/optim.
        self.save_dir = os.path.join(self.args.save_dir, 'optim')
        os.makedirs(self.save_dir, exist_ok=True)

        # Unwrap the Model container to reach the module holding G (and D).
        if isinstance(model, Model):
            model = model.model

        # Keyword arguments shared by all optimizer types.
        kwargs_optimizer = {
            'lr': args.lr,
            'weight_decay': args.weight_decay,
        }

        if args.optimizer == 'SGD':
            optimizer_class = optim.SGD
            kwargs_optimizer['momentum'] = args.momentum
        elif args.optimizer == 'ADAM':
            optimizer_class = optim.Adam
            kwargs_optimizer['betas'] = args.betas
            kwargs_optimizer['eps'] = args.epsilon
        elif args.optimizer == 'RMSPROP':
            optimizer_class = optim.RMSprop
            kwargs_optimizer['eps'] = args.epsilon
        else:
            raise NotImplementedError('unsupported optimizer: {}'.format(args.optimizer))
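
        # Example (hypothetical values): args.optimizer == 'ADAM' with
        # lr=1e-4, weight_decay=0, betas=(0.9, 0.999) and epsilon=1e-8
        # makes the _Optimizer built below behave like
        # optim.Adam(params, lr=1e-4, weight_decay=0, betas=(0.9, 0.999), eps=1e-8).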

        if args.scheduler == 'step':
            scheduler_class = lrs.MultiStepLR
            kwargs_scheduler = {
                'milestones': args.milestones,
                'gamma': args.gamma,
            }
        elif args.scheduler == 'plateau':
            scheduler_class = lrs.ReduceLROnPlateau
            kwargs_scheduler = {
                'mode': 'min',
                'factor': args.gamma,
                'patience': 10,
                'verbose': True,
                'threshold': 0,
                'threshold_mode': 'abs',
                'cooldown': 10,
            }
        else:
            raise NotImplementedError('unsupported scheduler: {}'.format(args.scheduler))

        self.kwargs_optimizer = kwargs_optimizer
        self.scheduler_class = scheduler_class
        self.kwargs_scheduler = kwargs_scheduler

        def _get_optimizer(model):
            # Subclass the chosen optimizer so each instance carries its own
            # scheduler; the closure reuses optimizer_class and
            # kwargs_optimizer for both G and D.
            class _Optimizer(optimizer_class):
                def __init__(self, model, args, scheduler_class, kwargs_scheduler):
                    # Optimize only parameters that require gradients.
                    trainable = filter(lambda x: x.requires_grad, model.parameters())
                    super(_Optimizer, self).__init__(trainable, **kwargs_optimizer)

                    self.args = args

                    self._register_scheduler(scheduler_class, kwargs_scheduler)

                def _register_scheduler(self, scheduler_class, kwargs_scheduler):
                    self.scheduler = scheduler_class(self, **kwargs_scheduler)

                def schedule(self, metrics=None):
                    # ReduceLROnPlateau steps on a monitored metric;
                    # other schedulers step unconditionally.
                    if isinstance(self.scheduler, lrs.ReduceLROnPlateau):
                        self.scheduler.step(metrics)
                    else:
                        self.scheduler.step()

                def get_last_epoch(self):
                    return self.scheduler.last_epoch

                def get_lr(self):
                    return self.param_groups[0]['lr']

                def get_last_lr(self):
                    return self.scheduler.get_last_lr()[0]

                def state_dict(self):
                    # Bundle the scheduler state with the optimizer state.
                    state_dict = super(_Optimizer, self).state_dict()
                    state_dict['scheduler'] = self.scheduler.state_dict()

                    return state_dict

                def load_state_dict(self, state_dict, epoch=None):
                    super(_Optimizer, self).load_state_dict(state_dict)
                    self.scheduler.load_state_dict(state_dict['scheduler'])

                    # If the configured milestones/gamma no longer match the
                    # checkpointed schedule, rebuild the learning rate from scratch.
                    reschedule = False
                    if isinstance(self.scheduler, lrs.MultiStepLR):
                        if self.args.milestones != list(self.scheduler.milestones) or self.args.gamma != self.scheduler.gamma:
                            reschedule = True

                    if reschedule:
                        if epoch is None:
                            if self.scheduler.last_epoch > 1:
                                epoch = self.scheduler.last_epoch
                            else:
                                epoch = self.args.start_epoch - 1

                        self.scheduler.milestones = Counter(self.args.milestones)
                        self.scheduler.gamma = self.args.gamma

                        # Reset each group's lr to its initial value, then apply
                        # gamma once for every milestone already passed.
                        for i, group in enumerate(self.param_groups):
                            self.param_groups[i]['lr'] = group['initial_lr']
                            multiplier = 1
                            for milestone in self.scheduler.milestones:
                                if epoch >= milestone:
                                    multiplier *= self.scheduler.gamma

                            self.param_groups[i]['lr'] *= multiplier
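
                        # Worked example (hypothetical numbers): with
                        # milestones=[500, 750] and gamma=0.5, resuming at
                        # epoch 600 gives lr = initial_lr * 0.5 (one milestone
                        # passed); at epoch 800 it would be initial_lr * 0.25.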

            return _Optimizer(model, args, scheduler_class, kwargs_scheduler)

        self.G = _get_optimizer(model.G)
        if model.D is not None:
            self.D = _get_optimizer(model.D)
        else:
            self.D = None

        self.load(args.load_epoch)

    def zero_grad(self):
        self.G.zero_grad()
        if self.D is not None:
            self.D.zero_grad()

    def step(self):
        self.G.step()
        if self.D is not None:
            self.D.step()

    def schedule(self, metrics=None):
        self.G.schedule(metrics)
        if self.D is not None:
            self.D.schedule(metrics)

    def get_last_epoch(self):
        return self.G.get_last_epoch()

    def get_lr(self):
        return self.G.get_lr()

    def get_last_lr(self):
        return self.G.get_last_lr()

    def state_dict(self):
        # Map is the attribute-style dict from utils; toDict() yields a plain dict.
        state_dict = Map()
        state_dict.G = self.G.state_dict()
        if self.D is not None:
            state_dict.D = self.D.state_dict()

        return state_dict.toDict()

    def load_state_dict(self, state_dict, epoch=None):
        state_dict = Map(**state_dict)
        self.G.load_state_dict(state_dict.G, epoch)
        if self.D is not None:
            self.D.load_state_dict(state_dict.D, epoch)

    def _save_path(self, epoch=None):
        epoch = epoch if epoch is not None else self.get_last_epoch()
        save_path = os.path.join(self.save_dir, 'optim-{:d}.pt'.format(epoch))

        return save_path

    def save(self, epoch=None):
        if epoch is None:
            epoch = self.G.scheduler.last_epoch
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch):
        if epoch > 0:
            print('Loading optimizer from {}'.format(self._save_path(epoch)))
            self.load_state_dict(torch.load(self._save_path(epoch), map_location=self.args.device), epoch=epoch)
        elif epoch == 0:
            # Fresh start: nothing to restore.
            pass
        else:
            raise NotImplementedError('cannot load optimizer state for epoch {}'.format(epoch))

        return
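
# Minimal usage sketch, assuming the trainer drives epochs this way.
# `loader`, `val_loss`, `compute_loss`, and `args.end_epoch` are hypothetical
# names; `args.start_epoch` and `args.load_epoch` come from the code above.
#
#   model = Model(args)
#   optimizer = Optimizer(args, model)
#   for epoch in range(args.start_epoch, args.end_epoch + 1):
#       for batch in loader:
#           optimizer.zero_grad()
#           loss = compute_loss(model, batch)   # forward + loss (hypothetical)
#           loss.backward()
#           optimizer.step()
#       optimizer.schedule(val_loss)            # metrics used by 'plateau' only
#       optimizer.save(epoch)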