# NOTE(review): the three lines below are web-upload-page residue, not Python;
# commented out so the file parses. Original text preserved:
# csun22's picture
# Upload 59 files
# ca1888b verified
#!/usr/bin/env python
"""
nn_manager
utilities used by nn_manager
"""
from __future__ import print_function
from collections import OrderedDict
import numpy as np
import torch
import core_scripts.other_tools.str_tools as nii_str_tk
import core_scripts.other_tools.display as nii_display
import core_scripts.nn_manager.nn_manager_conf as nii_nn_manage_conf
__author__ = "Xin Wang"
__email__ = "wangxin@nii.ac.jp"
__copyright__ = "Copyright 2020, Xin Wang"
#############################################################
def f_state_dict_wrapper(state_dict, data_parallel=False):
    """ Adjust state_dict keys for (non-)DataParallel model loading.

    new_state_dict = f_state_dict_wrapper(state_dict, data_parallel)

    Args:
      state_dict: pytorch state_dict (an ordered mapping of name -> tensor)
      data_parallel: bool, True if the target model is wrapped in DataParallel

    Returns:
      OrderedDict with every key carrying (data_parallel=True) or stripped
      of (data_parallel=False) the 'module.' prefix.

    Reference:
    https://discuss.pytorch.org/t/solved-keyerror-unexpected-
    key-module-encoder-embedding-weight-in-state-dict/1686/3
    """
    wrapped = OrderedDict()
    if data_parallel is True:
        # target model is DataParallel: every key must start with 'module'
        for key, value in state_dict.items():
            if key.startswith('module'):
                wrapped[key] = value
            else:
                wrapped['module.' + key] = value
    else:
        # plain model: drop the leading 'module.' (7 characters) if present
        for key, value in state_dict.items():
            if key.startswith('module'):
                wrapped[key[7:]] = value
            else:
                wrapped[key] = value
    return wrapped
def f_process_loss(loss):
""" loss, loss_value = f_process_loss(loss):
Input:
loss: returned by loss_wrapper.compute
It can be a torch.tensor or a list of torch.tensor
When it is a list, it should look like:
[[loss_1, loss_2, loss_3],
[true/false, true/false, true.false]]
where true / false tells whether the loss should be taken into
consideration for early-stopping
Output:
loss: a torch.tensor
loss_value: a torch number of a list of torch number
"""
if type(loss) is list:
loss_sum = loss[0][0]
loss_list = [loss[0][0].item()]
if len(loss[0]) > 1:
for loss_tmp in loss[0][1:]:
loss_sum += loss_tmp
loss_list.append(loss_tmp.item())
return loss_sum, loss_list, loss[1]
else:
return loss, [loss.item()], [True]
def f_load_pretrained_model_partially(model, model_paths, model_name_prefix):
    """ f_load_pretrained_model_partially(model, model_paths, model_name_prefix)
    Initialize part of the model with pre-trained models

    Input:
    -----
       model: torch model
       model_paths: str or list of str, path(s) to pre-trained model(s)
       model_name_prefix: str or list of str, model name prefix used by model
              for example, pre_trained_model.*** may be referred to as
              model.m_part1.*** in the new model. The prefix is "m_part1."
              An empty prefix loads the pre-trained keys unchanged.
    Output:
    ------
       None (model parameters are updated in place)
    """
    # accept a single path / prefix in addition to lists
    if type(model_paths) is str:
        model_path_tmp = [model_paths]
    else:
        model_path_tmp = model_paths
    if type(model_name_prefix) is str:
        model_prefix_tmp = [model_name_prefix]
    else:
        model_prefix_tmp = model_name_prefix

    model_dict = model.state_dict()

    for model_path, prefix in zip(model_path_tmp, model_prefix_tmp):
        # normalize the prefix to end with '.' (e.g. "m_part1" -> "m_part1.");
        # the emptiness check avoids an IndexError on an empty prefix
        if prefix and prefix[-1] != '.':
            prefix += '.'

        # map_location='cpu' so that checkpoints saved on GPU load on
        # CPU-only hosts; load_state_dict copies to the model's device
        pretrained_dict = torch.load(model_path, map_location='cpu')

        # 1. filter out unnecessary keys
        pretrained_dict = {prefix + k: v \
                           for k, v in pretrained_dict.items() \
                           if prefix + k in model_dict}
        print("Load model {:s} as {:s} ({:d} parameter buffers)".format(
            model_path, prefix, len(pretrained_dict.keys())))
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
    model.load_state_dict(model_dict)
    return
def f_save_epoch_name(args, epoch_idx, suffix=''):
    """ str = f_save_epoch_name(args, epoch_idx)
    Return the name of the model file saved during training

    Args:
      args: argument object by arg_parse, we will use
            args.save_epoch_name, args.save_model_dir, args.save_model_ext
      epoch_idx: int, epoch index
      suffix: str, a suffix appended to the name (default '')
    Return:
      str: name of epoch state file, e.g. epoch_001.pt
    """
    # zero-padded 3-digit epoch index, e.g. "epoch_001" + suffix
    file_name = "{}_{:03d}{}".format(args.save_epoch_name, epoch_idx, suffix)
    return nii_str_tk.f_realpath(
        args.save_model_dir, file_name, args.save_model_ext)
def f_save_trained_name(args, suffix=''):
    """ str = f_save_trained_name(args)
    Return the name of the best trained model file

    Args:
      args: argument object by arg_parse, we will use
            args.save_trained_name, args.save_model_dir, args.save_model_ext
      suffix: str, a suffix added to the name (default '')
    Return:
      str: name of trained network file, e.g., trained_network.pt
    """
    file_name = args.save_trained_name + suffix
    return nii_str_tk.f_realpath(
        args.save_model_dir, file_name, args.save_model_ext)
def f_model_check(pt_model, model_type=None):
    """ f_model_check(pt_model)
    Check whether the model contains all the necessary keywords

    Args:
    ----
      pt_model: a Pytorch model
      model_type: str or None, a flag indicating the type of network;
          selects which keyword bag from nn_manager_conf is used
    Return:
    -------
      None (dies via nii_display.f_die when a mandatory keyword is missing)
    """
    nii_display.f_print("Model check:")
    # pick the keyword bag for this model type, falling back to the default
    keywords_bag = nii_nn_manage_conf.nn_model_keywords_bags.get(
        model_type, nii_nn_manage_conf.nn_model_keywords_default)

    for tmpkey, (flag_mandatory, mes) in keywords_bag.items():
        if flag_mandatory:
            # mandatory keyword: missing -> terminate with an error
            if hasattr(pt_model, tmpkey):
                print("[OK]: %s found" % (tmpkey))
            else:
                nii_display.f_print("Please implement %s (%s)" % (tmpkey, mes))
                nii_display.f_die("[Error]: found no %s in Model" % (tmpkey))
        elif hasattr(pt_model, tmpkey):
            # optional keyword that the model provides
            print("[OK]: use %s, %s" % (tmpkey, mes))
        else:
            # optional keyword that the model omits
            print("[OK]: %s is ignored, %s" % (tmpkey, mes))
    nii_display.f_print("Model check done\n")
    return
def f_model_show(pt_model, do_model_def_check=True, model_type=None):
    """ f_model_show(pt_model, do_model_check=True)
    Print the information of the model

    Args:
      pt_model: a Pytorch model
      do_model_def_check: bool, whether to check the model definition
          via f_model_check (default True)
      model_type: str or None (default None), what type of network
    Return:
      None
    """
    if do_model_def_check:
        f_model_check(pt_model, model_type)

    nii_display.f_print("Model infor:")
    print(pt_model)
    # count only trainable parameters
    trainable_counts = [p.numel() for p in pt_model.parameters()
                        if p.requires_grad]
    nii_display.f_print(
        "Parameter number: {:d}\n".format(sum(trainable_counts)), "normal")
    return
def f_loss_check(loss_module, model_type=None):
    """ f_loss_check(loss_module)
    Check whether the loss module contains all the necessary keywords

    Args:
    ----
      loss_module: a class (the loss wrapper)
      model_type: str or None; selects which keyword bag from
          nn_manager_conf is used
    Return:
    -------
      None (dies via nii_display.f_die when a mandatory keyword is missing)
    """
    nii_display.f_print("Loss check")
    # pick the keyword bag for this model type, falling back to the default
    keywords_bag = nii_nn_manage_conf.loss_method_keywords_bags.get(
        model_type, nii_nn_manage_conf.loss_method_keywords_default)

    for tmpkey, (flag_mandatory, mes) in keywords_bag.items():
        has_key = hasattr(loss_module, tmpkey)
        if flag_mandatory and not has_key:
            # mandatory keyword missing -> terminate with an error
            nii_display.f_print("Please implement %s (%s)" % (tmpkey, mes))
            nii_display.f_die("[Error]: found no %s in Loss" % (tmpkey))
        elif not flag_mandatory and has_key:
            # optional keyword that the loss module provides
            print("[OK]: use %s, %s" % (tmpkey, mes))
        else:
            # mandatory-and-present / optional-and-absent: stay silent
            pass
    nii_display.f_print("Loss check done\n")
    return
def f_loss_show(loss_module, do_loss_def_check=True, model_type=None):
    """ f_loss_show(loss_module, do_loss_def_check=True, model_type=None)
    Check (and potentially print) information of the loss module.
    (The original docstring was copy-pasted from f_model_show and described
    the wrong function and parameters; it is corrected here.)

    Args:
      loss_module: a class (the loss wrapper)
      do_loss_def_check: bool, whether to check the loss definition
          via f_loss_check (default True)
      model_type: str or None (default None), what type of network
    Return:
      None
    """
    # no information is printed here because the loss wrapper
    # is usually not a torch.Module
    #nii_display.f_print("Loss infor:")
    if do_loss_def_check:
        f_loss_check(loss_module, model_type)
    #print(loss_module)
    return
# module entry point: print the module name when executed directly
if __name__ == "__main__":
    print("nn_manager_tools")