# Thesis/utils/solver/optimizer.py
# Uploaded by Ryan-Pham ("Upload 103 files"), commit beb7843.
import torch
from torch import optim
def build_optimizer(cfg, model, base_lr=0.0, resume=None):
    """Build an optimizer for ``model`` and optionally restore it from a checkpoint.

    Args:
        cfg: Mapping with keys ``'optimizer'`` (one of ``'sgd'``, ``'adam'``,
            ``'adamw'``), ``'momentum'`` and ``'weight_decay'``. ``'momentum'``
            is printed for every optimizer but only passed to SGD.
        model: Object whose ``parameters()`` are handed to the optimizer.
        base_lr: Initial learning rate given to the optimizer.
        resume: Optional checkpoint path (or file-like object) readable by
            ``torch.load``; it must contain an ``"optimizer"`` state dict and
            an ``"epoch"`` entry.

    Returns:
        Tuple ``(optimizer, start_epoch)``. ``start_epoch`` is 0 when not
        resuming, otherwise the checkpoint's ``"epoch"`` value as stored.

    Raises:
        ValueError: If ``cfg['optimizer']`` is not a supported name.
    """
    print('==============================')
    print('Optimizer: {}'.format(cfg['optimizer']))
    print('--momentum: {}'.format(cfg['momentum']))
    print('--weight_decay: {}'.format(cfg['weight_decay']))

    name = cfg['optimizer']
    if name == 'sgd':
        optimizer = optim.SGD(
            model.parameters(),
            lr=base_lr,
            momentum=cfg['momentum'],
            weight_decay=cfg['weight_decay'])
    elif name == 'adam':
        optimizer = optim.Adam(
            model.parameters(),
            lr=base_lr,
            weight_decay=cfg['weight_decay'])
    elif name == 'adamw':
        optimizer = optim.AdamW(
            model.parameters(),
            lr=base_lr,
            weight_decay=cfg['weight_decay'])
    else:
        # Fail loudly here rather than with an UnboundLocalError further down.
        raise ValueError(
            "Unsupported optimizer {!r}; expected one of "
            "'sgd', 'adam', 'adamw'".format(name))

    start_epoch = 0
    if resume is not None:
        print('keep training: ', resume)
        # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints.
        # Loading onto CPU lets a checkpoint saved on another device (e.g. a GPU)
        # be read anywhere; Optimizer.load_state_dict casts the state to each
        # parameter's device.
        checkpoint = torch.load(resume, map_location='cpu')
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
    return optimizer, start_epoch