| | |
| | """ |
| | nn_manager |
| | |
| | utilities used by nn_manager |
| | |
| | """ |
| | from __future__ import print_function |
| |
|
| | from collections import OrderedDict |
| | import numpy as np |
| | import torch |
| |
|
| | import core_scripts.other_tools.str_tools as nii_str_tk |
| | import core_scripts.other_tools.display as nii_display |
| | import core_scripts.nn_manager.nn_manager_conf as nii_nn_manage_conf |
| |
|
| | __author__ = "Xin Wang" |
| | __email__ = "wangxin@nii.ac.jp" |
| | __copyright__ = "Copyright 2020, Xin Wang" |
| |
|
| | |
| |
|
def f_state_dict_wrapper(state_dict, data_parallel=False):
    """ a wrapper to take care of state_dict when using DataParallism

    f_state_dict_wrapper(state_dict, data_parallel):
      state_dict: pytorch state_dict
      data_parallel: whether DataParallel is used (only the value True
                     switches to the "add prefix" mode)

    Return:
      a new OrderedDict holding the same values in the same order. When
      data_parallel is True, every key carries the 'module.' prefix (added
      if missing). Otherwise a leading 'module.' prefix is removed.

    https://discuss.pytorch.org/t/solved-keyerror-unexpected-
    key-module-encoder-embedding-weight-in-state-dict/1686/3
    """
    # Match the full 'module.' prefix rather than just 'module'. Otherwise a
    # parameter such as 'module_list.0.weight' would be mistaken for a
    # DataParallel key: it would be truncated when unwrapping, and left
    # without the prefix when wrapping.
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if data_parallel is True:
            name = k if k.startswith(prefix) else prefix + k
        else:
            name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict
| |
|
def f_process_loss(loss):
    """ loss, loss_value, loss_flags = f_process_loss(loss)

    Input:
      loss: returned by loss_wrapper.compute
        It can be a torch.tensor or a list of torch.tensor
        When it is a list, it should look like:
         [[loss_1, loss_2, loss_3],
          [true/false, true/false, true/false]]
        where true / false tells whether the loss should be taken into
        consideration for early-stopping

    Output:
      loss: a torch.tensor, the sum of all the loss terms
      loss_value: a list of python numbers, the value (.item()) of each term
      loss_flags: the list of true/false flags (loss[1]), or [True] when a
        single tensor is given
    """
    if isinstance(loss, list):
        loss_terms, loss_flags = loss[0], loss[1]
        # Use out-of-place addition. `loss_sum += x` would run an in-place
        # add on loss_terms[0]. That silently modifies the caller's tensor,
        # and autograd rejects it when the tensor is a leaf requiring grad.
        loss_sum = loss_terms[0]
        for loss_tmp in loss_terms[1:]:
            loss_sum = loss_sum + loss_tmp
        loss_list = [loss_tmp.item() for loss_tmp in loss_terms]
        return loss_sum, loss_list, loss_flags
    return loss, [loss.item()], [True]
| |
|
| |
|
def f_load_pretrained_model_partially(model, model_paths, model_name_prefix):
    """ f_load_pretrained_model_partially(model, model_paths, model_name_prefix)

    Initialize part of the model with pre-trained models

    Input:
    -----
       model: torch model
       model_paths: str or list of str, path(s) to pre-trained models
       model_name_prefix: str or list of str, model name prefix used by model
            for example, pre_trained_model.*** may be referred to as
            model.m_part1.*** in the new model. The prefix is "m_part1."
            (a trailing '.' is added if missing). An empty prefix "" loads
            the pre-trained parameters under their original names.
       model_paths[i] is paired with model_name_prefix[i]; extra entries in
       the longer of the two lists are ignored (zip semantics)

    Output:
    ------
       None; parameters of the pre-trained models whose (prefixed) names
       exist in model are copied into model, the others are skipped
    """
    if isinstance(model_paths, str):
        model_path_tmp = [model_paths]
    else:
        model_path_tmp = model_paths
    if isinstance(model_name_prefix, str):
        model_prefix_tmp = [model_name_prefix]
    else:
        model_prefix_tmp = model_name_prefix

    model_dict = model.state_dict()

    for model_path, prefix in zip(model_path_tmp, model_prefix_tmp):
        # Make sure the prefix ends with '.' so it forms valid parameter
        # names. An empty prefix (load into the model root) is left as is;
        # the old prefix[-1] check raised IndexError on it.
        if prefix and not prefix.endswith('.'):
            prefix += '.'

        # Load onto CPU so that checkpoints saved on GPU also load on
        # CPU-only machines. load_state_dict() below copies the values into
        # the model's own tensors, so the model's device is unchanged.
        pretrained_dict = torch.load(model_path, map_location='cpu')

        # Keep only the parameters that exist in the new model
        pretrained_dict = {prefix + k: v
                           for k, v in pretrained_dict.items()
                           if prefix + k in model_dict}
        print("Load model {:s} as {:s} ({:d} parameter buffers)".format(
            model_path, prefix, len(pretrained_dict.keys())))

        # Overwrite the matching entries
        model_dict.update(pretrained_dict)

    # Load the merged state dict back into the model
    model.load_state_dict(model_dict)
    return
| |
|
def f_save_epoch_name(args, epoch_idx, suffix=''):
    """ str = f_save_epoch_name(args, epoch_idx, suffix='')
    Build the name of the model file saved after a training epoch

    Args:
      args: argument object from arg_parse; this function reads
            args.save_epoch_name, args.save_model_dir and args.save_model_ext
      epoch_idx: int, epoch index (zero-padded to at least 3 digits)
      suffix: str, appended right after the epoch index (default '')

    Return:
      str: the name returned by nii_str_tk.f_realpath for the directory,
           the stem "<save_epoch_name>_<epoch_idx><suffix>" and the
           extension, e.g. epoch_001.pt
    """
    stem = "{}_{:03d}".format(args.save_epoch_name, epoch_idx)
    file_name = stem + suffix
    save_dir = args.save_model_dir
    return nii_str_tk.f_realpath(save_dir, file_name, args.save_model_ext)
| |
|
def f_save_trained_name(args, suffix=''):
    """ str = f_save_trained_name(args, suffix='')
    Build the name of the file that stores the best trained model

    Args:
      args: argument object from arg_parse; this function reads
            args.save_model_dir, args.save_trained_name and
            args.save_model_ext
      suffix: str, appended to args.save_trained_name (default '')

    Return:
      str: the name returned by nii_str_tk.f_realpath for the directory,
           the stem "<save_trained_name><suffix>" and the extension,
           e.g., trained_network.pt
    """
    save_dir = args.save_model_dir
    file_name = args.save_trained_name + suffix
    return nii_str_tk.f_realpath(save_dir, file_name, args.save_model_ext)
| |
|
| |
|
def f_model_check(pt_model, model_type=None):
    """ f_model_check(pt_model, model_type=None)
    Check whether the model defines every keyword listed in the keyword
    bag of its model type

    Args:
    ----
      pt_model: a Pytorch model
      model_type: str or None, a flag indicating the type of network.
         Types that are not keys of nii_nn_manage_conf.nn_model_keywords_bags
         use nii_nn_manage_conf.nn_model_keywords_default instead.
         Each bag entry maps keyword -> (is_mandatory, message).

    Return:
    -------
      None. For a missing mandatory keyword, a hint is printed and
      nii_display.f_die is called (presumably aborting; confirm in display).
    """
    nii_display.f_print("Model check:")
    bags = nii_nn_manage_conf.nn_model_keywords_bags
    if model_type in bags:
        keywords_bag = bags[model_type]
    else:
        keywords_bag = nii_nn_manage_conf.nn_model_keywords_default

    for keyword, (mandatory, message) in keywords_bag.items():
        found = hasattr(pt_model, keyword)
        if mandatory and not found:
            nii_display.f_print("Please implement %s (%s)" % (keyword, message))
            nii_display.f_die("[Error]: found no %s in Model" % (keyword))
        elif mandatory:
            print("[OK]: %s found" % (keyword))
        elif not found:
            # optional keyword that the model does not provide
            print("[OK]: %s is ignored, %s" % (keyword, message))
        else:
            print("[OK]: use %s, %s" % (keyword, message))

    nii_display.f_print("Model check done\n")
    return
| |
|
def f_model_show(pt_model, do_model_def_check=True, model_type=None):
    """ f_model_show(pt_model, do_model_def_check=True, model_type=None)
    Print the information of the model

    Args:
      pt_model: a Pytorch model
      do_model_def_check: bool, whether to run f_model_check first
                          (default True)
      model_type: str or None (default None), what type of network;
                  forwarded to f_model_check

    Return:
      None. Prints the model and its number of trainable parameters
      (those with requires_grad set).
    """
    if do_model_def_check:
        f_model_check(pt_model, model_type)

    nii_display.f_print("Model infor:")
    print(pt_model)
    trainable_sizes = [p.numel() for p in pt_model.parameters()
                       if p.requires_grad]
    total = sum(trainable_sizes)
    nii_display.f_print("Parameter number: {:d}\n".format(total), "normal")
    return
| |
|
| |
|
def f_loss_check(loss_module, model_type=None):
    """ f_loss_check(loss_module, model_type=None)
    Check whether the loss module defines every keyword listed in the
    keyword bag of its model type

    Args:
    ----
      loss_module: a class (or object) implementing the loss
      model_type: a str or None. Types that are not keys of
         nii_nn_manage_conf.loss_method_keywords_bags use
         nii_nn_manage_conf.loss_method_keywords_default instead.
         Each bag entry maps keyword -> (is_mandatory, message).

    Return:
    -------
      None. Only optional keywords that are present get reported. A
      missing mandatory keyword triggers a hint and nii_display.f_die
      (presumably aborting; confirm in display).
    """
    nii_display.f_print("Loss check")

    bags = nii_nn_manage_conf.loss_method_keywords_bags
    if model_type in bags:
        keywords_bag = bags[model_type]
    else:
        keywords_bag = nii_nn_manage_conf.loss_method_keywords_default

    for keyword, (mandatory, message) in keywords_bag.items():
        found = hasattr(loss_module, keyword)
        if mandatory and not found:
            nii_display.f_print("Please implement %s (%s)" % (keyword, message))
            nii_display.f_die("[Error]: found no %s in Loss" % (keyword))
        elif found and not mandatory:
            print("[OK]: use %s, %s" % (keyword, message))
        # mandatory-and-found and optional-and-missing are silent

    nii_display.f_print("Loss check done\n")
    return
| |
|
def f_loss_show(loss_module, do_loss_def_check=True, model_type=None):
    """ f_loss_show(loss_module, do_loss_def_check=True, model_type=None)
    Optionally check the definition of the loss module

    Args:
      loss_module: a class (or object) implementing the loss
      do_loss_def_check: bool, whether to run f_loss_check (default True)
      model_type: str or None (default None), what type of network;
                  forwarded to f_loss_check

    Return:
      None
    """
    if not do_loss_def_check:
        return
    f_loss_check(loss_module, model_type)
    return
| |
|
if __name__ == "__main__":
    # This module only provides helper functions; running it directly
    # just prints the module name.
    print("nn_manager_tools")
| |
|