LYL1015's picture
test
eea83e8
import torch.optim as optim
def get_optimizer(config, parameters):
if config.optim.optimizer == 'Adam':
optimizer = optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
betas=(0.9, 0.999), amsgrad=config.optim.amsgrad, eps=config.optim.eps)
elif config.optim.optimizer == 'RMSProp':
optimizer = optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
elif config.optim.optimizer == 'SGD':
optimizer = optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
else:
raise NotImplementedError('Optimizer {} not understood.'.format(config.optim.optimizer))
return optimizer