| | import torch
|
| | import torch.optim as optim
|
| | import torch.optim.lr_scheduler as lrs
|
| |
|
| | import os
|
| | from collections import Counter
|
| |
|
| | from model import Model
|
| | from utils import interact, Map
|
| |
|
class Optimizer(object):
    """Bundle the generator optimizer (and optional discriminator optimizer)
    together with their LR schedulers and checkpoint save/load logic.

    Exposes ``self.G`` (always) and ``self.D`` (or None) — each is a
    dynamically-built subclass of the chosen ``torch.optim`` optimizer that
    carries its own scheduler and knows how to (re)schedule itself.
    """

    def __init__(self, args, model):
        self.args = args

        # Optimizer checkpoints live under <save_dir>/optim as optim-<epoch>.pt.
        self.save_dir = os.path.join(self.args.save_dir, 'optim')
        os.makedirs(self.save_dir, exist_ok=True)

        # Unwrap the project-level Model container to reach the raw module
        # holder that exposes .G / .D.
        if isinstance(model, Model):
            model = model.model

        kwargs_optimizer = {
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }

        if args.optimizer == 'SGD':
            optimizer_class = optim.SGD
            kwargs_optimizer['momentum'] = args.momentum
        elif args.optimizer == 'ADAM':
            optimizer_class = optim.Adam
            kwargs_optimizer['betas'] = args.betas
            kwargs_optimizer['eps'] = args.epsilon
        elif args.optimizer == 'RMSPROP':
            optimizer_class = optim.RMSprop
            kwargs_optimizer['eps'] = args.epsilon
        else:
            # Fail fast with a clear message instead of a NameError below.
            raise ValueError('unsupported optimizer: {}'.format(args.optimizer))

        if args.scheduler == 'step':
            scheduler_class = lrs.MultiStepLR
            kwargs_scheduler = {
                'milestones': args.milestones,
                'gamma': args.gamma,
            }
        elif args.scheduler == 'plateau':
            scheduler_class = lrs.ReduceLROnPlateau
            kwargs_scheduler = {
                'mode': 'min',
                'factor': args.gamma,
                'patience': 10,
                'verbose': True,
                'threshold': 0,
                'threshold_mode': 'abs',
                'cooldown': 10,
            }
        else:
            raise ValueError('unsupported scheduler: {}'.format(args.scheduler))

        self.kwargs_optimizer = kwargs_optimizer
        self.scheduler_class = scheduler_class
        self.kwargs_scheduler = kwargs_scheduler

        def _get_optimizer(model):
            """Build a scheduler-aware optimizer over *model*'s trainable params."""

            class _Optimizer(optimizer_class):
                def __init__(self, model, args, scheduler_class, kwargs_scheduler):
                    # Only parameters with requires_grad participate.
                    trainable = filter(lambda x: x.requires_grad, model.parameters())
                    super(_Optimizer, self).__init__(trainable, **kwargs_optimizer)

                    self.args = args

                    self._register_scheduler(scheduler_class, kwargs_scheduler)

                def _register_scheduler(self, scheduler_class, kwargs_scheduler):
                    # The scheduler wraps this very optimizer instance.
                    self.scheduler = scheduler_class(self, **kwargs_scheduler)

                def schedule(self, metrics=None):
                    # BUG FIX: the original tested isinstance(self, ...), but
                    # `self` is an optimizer, never a scheduler, so the plateau
                    # branch was unreachable and ReduceLROnPlateau was stepped
                    # without its required metrics. Check the attached
                    # scheduler instead.
                    if isinstance(self.scheduler, lrs.ReduceLROnPlateau):
                        self.scheduler.step(metrics)
                    else:
                        self.scheduler.step()

                def get_last_epoch(self):
                    return self.scheduler.last_epoch

                def get_lr(self):
                    # Current LR of the first param group.
                    return self.param_groups[0]['lr']

                def get_last_lr(self):
                    return self.scheduler.get_last_lr()[0]

                def state_dict(self):
                    # Embed the scheduler state alongside the optimizer state
                    # so a single file restores both.
                    state_dict = super(_Optimizer, self).state_dict()
                    state_dict['scheduler'] = self.scheduler.state_dict()

                    return state_dict

                def load_state_dict(self, state_dict, epoch=None):
                    super(_Optimizer, self).load_state_dict(state_dict)

                    self.scheduler.load_state_dict(state_dict['scheduler'])

                    # If the command-line schedule differs from the checkpointed
                    # one, rebuild the LR trajectory from scratch.
                    reschedule = False
                    if isinstance(self.scheduler, lrs.MultiStepLR):
                        if self.args.milestones != list(self.scheduler.milestones) or self.args.gamma != self.scheduler.gamma:
                            reschedule = True

                    if reschedule:
                        if epoch is None:
                            if self.scheduler.last_epoch > 1:
                                epoch = self.scheduler.last_epoch
                            else:
                                epoch = self.args.start_epoch - 1

                        # Adopt the new schedule, then replay it: reset each
                        # group to its initial LR and apply gamma once per
                        # milestone already passed.
                        self.scheduler.milestones = Counter(self.args.milestones)
                        self.scheduler.gamma = self.args.gamma
                        for i, group in enumerate(self.param_groups):
                            self.param_groups[i]['lr'] = group['initial_lr']
                            multiplier = 1
                            for milestone in self.scheduler.milestones:
                                if epoch >= milestone:
                                    multiplier *= self.scheduler.gamma

                            self.param_groups[i]['lr'] *= multiplier

            return _Optimizer(model, args, scheduler_class, kwargs_scheduler)

        self.G = _get_optimizer(model.G)
        if model.D is not None:
            self.D = _get_optimizer(model.D)
        else:
            self.D = None

        self.load(args.load_epoch)

    def zero_grad(self):
        # NOTE(review): only G is zeroed/stepped here — presumably the D
        # optimizer is driven directly by the adversarial training loop;
        # confirm against the caller.
        self.G.zero_grad()

    def step(self):
        self.G.step()

    def schedule(self, metrics=None):
        """Advance both schedulers (metrics only matters for plateau mode)."""
        self.G.schedule(metrics)
        if self.D is not None:
            self.D.schedule(metrics)

    def get_last_epoch(self):
        return self.G.get_last_epoch()

    def get_lr(self):
        return self.G.get_lr()

    def get_last_lr(self):
        return self.G.get_last_lr()

    def state_dict(self):
        """Return a plain dict with 'G' (and 'D' when present) sub-states."""
        state_dict = Map()
        state_dict.G = self.G.state_dict()
        if self.D is not None:
            state_dict.D = self.D.state_dict()

        return state_dict.toDict()

    def load_state_dict(self, state_dict, epoch=None):
        state_dict = Map(**state_dict)
        self.G.load_state_dict(state_dict.G, epoch)
        if self.D is not None:
            self.D.load_state_dict(state_dict.D, epoch)

    def _save_path(self, epoch=None):
        """Checkpoint path for *epoch* (defaults to the scheduler's last epoch)."""
        epoch = epoch if epoch is not None else self.get_last_epoch()
        save_path = os.path.join(self.save_dir, 'optim-{:d}.pt'.format(epoch))

        return save_path

    def save(self, epoch=None):
        if epoch is None:
            epoch = self.G.scheduler.last_epoch
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch):
        """Restore optimizer state for *epoch* (> 0); epoch 0 means fresh start.

        Raises NotImplementedError for negative epochs (e.g. "latest" lookup
        is not supported here).
        """
        if epoch > 0:
            print('Loading optimizer from {}'.format(self._save_path(epoch)))
            self.load_state_dict(torch.load(self._save_path(epoch), map_location=self.args.device), epoch=epoch)

        elif epoch == 0:
            pass
        else:
            raise NotImplementedError

        return
|
| |
|
| |
|