#!/usr/bin/env python
"""
op_manager
A simple wrapper to create optimizer
"""
from __future__ import absolute_import
import os
import sys
import numpy as np
import torch
import torch.optim as torch_optim
import torch.optim.lr_scheduler as torch_optim_steplr
import core_scripts.other_tools.list_tools as nii_list_tools
import core_scripts.other_tools.display as nii_warn
import core_scripts.other_tools.str_tools as nii_str_tk
import core_scripts.op_manager.conf as nii_op_config
import core_scripts.op_manager.op_process_monitor as nii_op_monitor
import core_scripts.op_manager.lr_scheduler as nii_lr_scheduler
__author__ = "Xin Wang"
__email__ = "wangxin@nii.ac.jp"
__copyright__ = "Copyright 2020, Xin Wang"
class OptimizerWrapper():
""" Wrapper over optimizer
"""
def __init__(self, model, args):
""" Initialize an optimizer over model.parameters()
"""
# check valildity of model
if not hasattr(model, "parameters"):
nii_warn.f_print("model is not torch.nn", "error")
nii_warn.f_die("Error in creating OptimizerWrapper")
        # set optimizer type
        self.op_flag = args.optimizer
        self.lr = args.lr
        self.l2_penalty = args.l2_penalty
        # the gradient-clipping norm is applied in nn_manager, not here
        self.grad_clip_norm = args.grad_clip_norm

        # create optimizer
        if self.op_flag == "Adam":
            if self.l2_penalty > 0:
                self.optimizer = torch_optim.Adam(model.parameters(),
                                                  lr=self.lr,
                                                  weight_decay=self.l2_penalty)
            else:
                self.optimizer = torch_optim.Adam(model.parameters(),
                                                  lr=self.lr)
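            # note: torch.optim.Adam defaults to weight_decay=0, so the
            # two branches above are equivalent to a single call with
            # weight_decay=self.l2_penalty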
        else:
            nii_warn.f_print("optimizer %s is not available" % (self.op_flag),
                             "error")
            nii_warn.f_die("Please change the optimizer")
        # number of epochs
        self.epochs = args.epochs
        self.no_best_epochs = args.no_best_epochs

        # lr scheduler
        self.lr_scheduler = nii_lr_scheduler.LRScheduler(self.optimizer, args)
        return

    def print_info(self):
        """ Print the optimizer configuration
        """
        mes = "Optimizer:\n Type: {} ".format(self.op_flag)
        mes += "\n Learning rate: {:2.6f}".format(self.lr)
        mes += "\n Epochs: {:d}".format(self.epochs)
        mes += "\n No-best-epochs: {:d}".format(self.no_best_epochs)
        if self.lr_scheduler.f_valid():
            mes += self.lr_scheduler.f_print_info()
        if self.l2_penalty > 0:
            mes += "\n With weight penalty {:f}".format(self.l2_penalty)
        if self.grad_clip_norm > 0:
            mes += "\n With grad clip norm {:f}".format(self.grad_clip_norm)
        nii_warn.f_print_message(mes)
    def get_epoch_num(self):
        """ Return the configured number of training epochs """
        return self.epochs

    def get_no_best_epoch_num(self):
        """ Return the early-stopping patience (no-best-epochs) """
        return self.no_best_epochs
    def get_lr_info(self):
        """ Return a string describing the updated learning rate(s),
        an empty string if the lr is unchanged, or None if no lr
        scheduler is used
        """
        if self.lr_scheduler.f_valid():
            # the updated lr can only be retrieved through _last_lr
            tmp = ''
            for updated_lr in self.lr_scheduler.f_last_lr():
                # report only lr values that differ from the initial lr
                if np.abs(self.lr - updated_lr) > 1e-7:
                    tmp += "{:.2e} ".format(updated_lr)
            if tmp:
                tmp = " LR -> " + tmp
            return tmp
        else:
            return None
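    # For example, if the scheduler has lowered the lr from 1e-4 to 5e-5,
    # get_lr_info() returns " LR -> 5.00e-05 " (values here are
    # illustrative only).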

if __name__ == "__main__":
    print("Optimizer Wrapper")