# Source: Hugging Face Hub upload by hyliu ("Upload folder using huggingface_hub"),
# commit 8ec10cf (verified).
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
import os
from collections import Counter
from model import Model
from utils import interact, Map
class Optimizer(object):
    """Pair of optimizers (generator ``G`` and optional discriminator ``D``).

    Each optimizer is a dynamically created subclass of the selected
    ``torch.optim`` class that also owns its learning-rate scheduler, so that
    the optimizer and scheduler state are saved and restored together.

    Attributes read from ``args``: ``save_dir``, ``lr``, ``weight_decay``,
    ``optimizer`` ('SGD' | 'ADAM' | 'RMSPROP'), ``momentum``, ``betas``,
    ``epsilon``, ``scheduler`` ('step' | 'plateau'), ``milestones``, ``gamma``,
    ``load_epoch``, ``start_epoch`` and ``device``.
    """

    def __init__(self, args, model):
        """Build the G/D optimizers and optionally resume from a checkpoint.

        Args:
            args: configuration namespace (see the class docstring).
            model: either a ``Model`` wrapper (its ``.model`` attribute is
                used) or an object exposing ``G`` and ``D`` sub-modules;
                ``D`` may be ``None``.

        Raises:
            NotImplementedError: if ``args.optimizer`` or ``args.scheduler``
                is not one of the supported names, or if ``args.load_epoch``
                is negative.
        """
        self.args = args

        self.save_dir = os.path.join(self.args.save_dir, 'optim')
        os.makedirs(self.save_dir, exist_ok=True)

        if isinstance(model, Model):
            model = model.model

        # set base arguments
        kwargs_optimizer = {
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }

        if args.optimizer == 'SGD':
            optimizer_class = optim.SGD
            kwargs_optimizer['momentum'] = args.momentum
        elif args.optimizer == 'ADAM':
            optimizer_class = optim.Adam
            kwargs_optimizer['betas'] = args.betas
            kwargs_optimizer['eps'] = args.epsilon
        elif args.optimizer == 'RMSPROP':
            optimizer_class = optim.RMSprop
            kwargs_optimizer['eps'] = args.epsilon
        else:
            # Fail early with a clear message instead of an unbound-name
            # error further down.
            raise NotImplementedError(
                'Unsupported optimizer: {}'.format(args.optimizer))

        # scheduler
        if args.scheduler == 'step':
            scheduler_class = lrs.MultiStepLR
            kwargs_scheduler = {
                'milestones': args.milestones,
                'gamma': args.gamma,
            }
        elif args.scheduler == 'plateau':
            scheduler_class = lrs.ReduceLROnPlateau
            kwargs_scheduler = {
                'mode': 'min',
                'factor': args.gamma,
                'patience': 10,
                'verbose': True,
                'threshold': 0,
                'threshold_mode': 'abs',
                'cooldown': 10,
            }
        else:
            raise NotImplementedError(
                'Unsupported scheduler: {}'.format(args.scheduler))

        self.kwargs_optimizer = kwargs_optimizer
        self.scheduler_class = scheduler_class
        self.kwargs_scheduler = kwargs_scheduler

        def _get_optimizer(model):
            """Create an optimizer (with its scheduler attached) for ``model``."""

            class _Optimizer(optimizer_class):
                """Selected torch optimizer extended with an owned scheduler."""

                def __init__(self, model, args, scheduler_class, kwargs_scheduler):
                    # Only parameters that require gradients are optimized.
                    trainable = filter(lambda x: x.requires_grad, model.parameters())
                    super(_Optimizer, self).__init__(trainable, **kwargs_optimizer)

                    self.args = args

                    self._register_scheduler(scheduler_class, kwargs_scheduler)

                def _register_scheduler(self, scheduler_class, kwargs_scheduler):
                    """Instantiate the scheduler bound to this optimizer."""
                    self.scheduler = scheduler_class(self, **kwargs_scheduler)

                def schedule(self, metrics=None):
                    """Advance the scheduler by one step.

                    ``metrics`` is forwarded only to ``ReduceLROnPlateau``,
                    whose ``step`` expects the monitored value.
                    """
                    # The check must look at the scheduler: ``self`` is the
                    # optimizer and is never a ReduceLROnPlateau instance.
                    if isinstance(self.scheduler, lrs.ReduceLROnPlateau):
                        self.scheduler.step(metrics)
                    else:
                        self.scheduler.step()

                def get_last_epoch(self):
                    """Return the scheduler's ``last_epoch`` counter."""
                    return self.scheduler.last_epoch

                def get_lr(self):
                    """Return the current learning rate of the first param group."""
                    return self.param_groups[0]['lr']

                def get_last_lr(self):
                    """Return the first entry of the scheduler's ``get_last_lr()``."""
                    return self.scheduler.get_last_lr()[0]

                def state_dict(self):
                    """Return the optimizer state plus a ``'scheduler'`` entry."""
                    # {'state': ..., 'param_groups': ...}
                    state_dict = super(_Optimizer, self).state_dict()
                    state_dict['scheduler'] = self.scheduler.state_dict()

                    return state_dict

                def load_state_dict(self, state_dict, epoch=None):
                    """Restore optimizer and scheduler state.

                    If a ``MultiStepLR`` checkpoint was saved with milestones
                    or gamma different from the current ``args``, the
                    scheduler is re-targeted to the current settings and the
                    learning rates are recomputed for ``epoch``.

                    Args:
                        state_dict: dict as produced by ``state_dict()``.
                        epoch: epoch used to recompute the learning rate when
                            rescheduling; if ``None`` it is taken from the
                            scheduler's ``last_epoch`` when that is > 1,
                            otherwise from ``args.start_epoch - 1``.
                    """
                    # optimizer: loads 'state' and 'param_groups' only
                    super(_Optimizer, self).load_state_dict(state_dict)
                    # scheduler: should work for plateau or simple resuming
                    self.scheduler.load_state_dict(state_dict['scheduler'])

                    reschedule = False
                    if isinstance(self.scheduler, lrs.MultiStepLR):
                        if self.args.milestones != list(self.scheduler.milestones) or self.args.gamma != self.scheduler.gamma:
                            reschedule = True

                    if not reschedule:
                        return

                    if epoch is None:
                        if self.scheduler.last_epoch > 1:
                            epoch = self.scheduler.last_epoch
                        else:
                            epoch = self.args.start_epoch - 1

                    # Modify the existing scheduler in place rather than
                    # building a new one.
                    self.scheduler.milestones = Counter(self.args.milestones)
                    self.scheduler.gamma = self.args.gamma

                    # The decay factor depends only on the epoch, so compute
                    # it once. A milestone listed k times decays by gamma**k,
                    # matching how MultiStepLR treats its milestone Counter.
                    multiplier = 1
                    for milestone, count in self.scheduler.milestones.items():
                        if epoch >= milestone:
                            multiplier *= self.scheduler.gamma ** count

                    for group in self.param_groups:
                        # reset optimizer learning rate to initial, then decay
                        group['lr'] = group['initial_lr']
                        group['lr'] *= multiplier

            return _Optimizer(model, args, scheduler_class, kwargs_scheduler)

        self.G = _get_optimizer(model.G)
        if model.D is not None:
            self.D = _get_optimizer(model.D)
        else:
            self.D = None

        self.load(args.load_epoch)

    def zero_grad(self):
        """Clear the gradients tracked by the G optimizer."""
        # NOTE(review): D is not touched here; presumably it is handled by
        # the code that trains the discriminator -- confirm against callers.
        self.G.zero_grad()

    def step(self):
        """Apply one update step with the G optimizer."""
        self.G.step()

    def schedule(self, metrics=None):
        """Advance the G scheduler and, when present, the D scheduler."""
        self.G.schedule(metrics)
        if self.D is not None:
            self.D.schedule(metrics)

    def get_last_epoch(self):
        """Return the ``last_epoch`` counter of the G scheduler."""
        return self.G.get_last_epoch()

    def get_lr(self):
        """Return the current learning rate of G's first param group."""
        return self.G.get_lr()

    def get_last_lr(self):
        """Return the last learning rate reported by the G scheduler."""
        return self.G.get_last_lr()

    def state_dict(self):
        """Return a plain dict with the ``G`` (and, if present, ``D``) states."""
        state_dict = Map()
        state_dict.G = self.G.state_dict()
        if self.D is not None:
            state_dict.D = self.D.state_dict()

        return state_dict.toDict()

    def load_state_dict(self, state_dict, epoch=None):
        """Restore G (and, if present, D) from a dict made by ``state_dict()``."""
        state_dict = Map(**state_dict)
        self.G.load_state_dict(state_dict.G, epoch)
        if self.D is not None:
            self.D.load_state_dict(state_dict.D, epoch)

    def _save_path(self, epoch=None):
        """Return the checkpoint path for ``epoch`` (default: last epoch)."""
        epoch = epoch if epoch is not None else self.get_last_epoch()
        save_path = os.path.join(self.save_dir, 'optim-{:d}.pt'.format(epoch))

        return save_path

    def save(self, epoch=None):
        """Write the combined state dict to the checkpoint file for ``epoch``."""
        if epoch is None:
            epoch = self.G.scheduler.last_epoch
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch):
        """Load the checkpoint saved for ``epoch``.

        ``epoch == 0`` means "start fresh" and loads nothing.

        Raises:
            NotImplementedError: if ``epoch`` is negative.
        """
        if epoch > 0:
            print('Loading optimizer from {}'.format(self._save_path(epoch)))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            self.load_state_dict(torch.load(self._save_path(epoch), map_location=self.args.device), epoch=epoch)
        elif epoch == 0:
            pass
        else:
            raise NotImplementedError