"""
nn_manager_gan

A simple wrapper to run the training / testing process for GANs
"""
from __future__ import print_function

import time
import datetime
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import core_scripts.data_io.conf as nii_dconf
import core_scripts.other_tools.display as nii_display
import core_scripts.other_tools.str_tools as nii_str_tk
import core_scripts.op_manager.op_process_monitor as nii_monitor
import core_scripts.op_manager.op_display_tools as nii_op_display_tk
import core_scripts.nn_manager.nn_manager_tools as nii_nn_tools
import core_scripts.nn_manager.nn_manager_conf as nii_nn_manage_conf
import core_scripts.other_tools.debug as nii_debug

__author__ = "Xin Wang"
__email__ = "wangxin@nii.ac.jp"
__copyright__ = "Copyright 2020, Xin Wang"
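
# Example (illustrative sketch only, not part of this module's API): a
# top-level training script would typically call f_train_wrapper_GAN as
# below. Generator, Discriminator, and ExampleGANLoss are hypothetical
# placeholders defined by the user's model script:
#
#   model_G = Generator(...)
#   model_D = Discriminator(...)
#   f_train_wrapper_GAN(args, model_G, model_D, ExampleGANLoss(), device,
#                       optimizer_G_wrapper, optimizer_D_wrapper,
#                       trn_set_wrapper, val_set_wrapper)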


def f_run_one_epoch_GAN(
        args, pt_model_G, pt_model_D,
        loss_wrapper,
        device, monitor,
        data_loader, epoch_idx,
        optimizer_G = None, optimizer_D = None,
        target_norm_method = None):
| """ |
| f_run_one_epoch_GAN: |
| run one poech over the dataset (for training or validation sets) |
| |
| Args: |
| args: from argpase |
| pt_model_G: pytorch model (torch.nn.Module) generator |
| pt_model_D: pytorch model (torch.nn.Module) discriminator |
| loss_wrapper: a wrapper over loss function |
| loss_wrapper.compute(generated, target) |
| device: torch.device("cuda") or torch.device("cpu") |
| monitor: defined in op_procfess_monitor.py |
| data_loader: pytorch DataLoader. |
| epoch_idx: int, index of the current epoch |
| optimizer_G: torch optimizer or None, for generator |
| optimizer_D: torch optimizer or None, for discriminator |
| if None, the back propgation will be skipped |
| (for developlement set) |
| target_norm_method: method to normalize target data |
| (by default, use pt_model.normalize_target) |
| """ |

    # timer
    start_time = time.time()

    # loop over each mini-batch in the data set
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        #############
        # prepare
        #############
        # zero the gradient buffers
        if optimizer_G is not None:
            optimizer_G.zero_grad()
        if optimizer_D is not None:
            optimizer_D.zero_grad()

        # send the target data to the device
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # normalize the target data
            if target_norm_method is None:
                normed_target = pt_model_G.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            nii_display.f_die("target data is required")

        # send the input data to the device
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)

        ############################
        # Update the discriminator
        ############################
        # train the discriminator with real data
        pt_model_D.zero_grad()
        d_out_real = pt_model_D(data_tar, data_in)
        errD_real = loss_wrapper.compute_gan_D_real(d_out_real)
        if optimizer_D is not None:
            errD_real.backward()

        # generate the fake data
        if args.model_forward_with_target:
            # if the model's forward function also requires the target data
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model_G(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model_G(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # if the model's forward function requires the file name
                data_gen = pt_model_G(data_in, data_info)
            else:
                # default case
                data_gen = pt_model_G(data_in)

        # train the discriminator with fake data;
        # detach() stops the gradient from flowing into the generator
        d_out_fake = pt_model_D(data_gen.detach(), data_in)
        errD_fake = loss_wrapper.compute_gan_D_fake(d_out_fake)
        if optimizer_D is not None:
            errD_fake.backward()

        # total discriminator loss (for reference only; the gradients were
        # already accumulated by the two backward() calls above)
        errD = errD_real + errD_fake

        # update the discriminator
        if optimizer_D is not None:
            optimizer_D.step()

        ############################
        # Update the generator
        ############################
        pt_model_G.zero_grad()
        # re-evaluate the discriminator on the (non-detached) fake data
        d_out_fake_for_G = pt_model_D(data_gen, data_in)
        errG_gan = loss_wrapper.compute_gan_G(d_out_fake_for_G)

        # auxiliary loss (e.g., a spectral loss), if defined
        if hasattr(loss_wrapper, "compute_aux"):
            errG_aux = loss_wrapper.compute_aux(data_gen, data_tar)
        else:
            errG_aux = torch.zeros_like(errG_gan)

        # feature-matching loss, if defined
        if hasattr(loss_wrapper, "compute_feat_match"):
            errG_feat = loss_wrapper.compute_feat_match(
                d_out_real, d_out_fake_for_G)
        else:
            errG_feat = torch.zeros_like(errG_gan)

        # total generator loss
        errG = errG_gan + errG_aux + errG_feat

        # update the generator
        if optimizer_G is not None:
            errG.backward()
            optimizer_G.step()

        # construct the list of losses for the monitor:
        # the first list holds the loss values, the second list flags
        # which values count toward the main loss criterion
        loss_computed = [
            [errG_aux, errD_real, errD_fake, errG_gan, errG_feat],
            [True, False, False, False, False]]

        # to handle the case of multiple loss functions
        _, loss_vals, loss_flags = nii_nn_tools.f_process_loss(loss_computed)

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is the average over the samples in the batch
            monitor.log_loss(loss_vals, loss_flags,
                             (end_time - start_time) / batchsize,
                             data_seq_info, idx_orig.numpy()[idx],
                             epoch_idx)
            # print the information
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx,
                                              idx_orig.numpy()[idx],
                                              epoch_idx)

        # start the timer for the next mini-batch
        start_time = time.time()

    # loop done
    return


def f_run_one_epoch_WGAN(
        args, pt_model_G, pt_model_D,
        loss_wrapper,
        device, monitor,
        data_loader, epoch_idx,
        optimizer_G = None, optimizer_D = None,
        target_norm_method = None):
| """ |
| f_run_one_epoch_WGAN: |
| similar to f_run_one_epoch_GAN, but for WGAN |
| """ |

    # timer
    start_time = time.time()

    # number of critic (discriminator) updates per generator update
    num_critic = 5
    # clipping bound for the critic weights (WGAN weight clipping)
    wgan_clamp = 0.01

    # loop over each mini-batch in the data set
    for data_idx, (data_in, data_tar, data_info, idx_orig) in \
        enumerate(data_loader):

        # zero the gradient buffers
        if optimizer_G is not None:
            optimizer_G.zero_grad()
        if optimizer_D is not None:
            optimizer_D.zero_grad()

        # send the target data to the device
        if isinstance(data_tar, torch.Tensor):
            data_tar = data_tar.to(device, dtype=nii_dconf.d_dtype)
            # normalize the target data
            if target_norm_method is None:
                normed_target = pt_model_G.normalize_target(data_tar)
            else:
                normed_target = target_norm_method(data_tar)
        else:
            nii_display.f_die("target data is required")

        # send the input data to the device
        data_in = data_in.to(device, dtype=nii_dconf.d_dtype)

        ############################
        # Update the critic
        ############################
        # train the critic with real data
        pt_model_D.zero_grad()
        d_out_real = pt_model_D(data_tar)
        errD_real = loss_wrapper.compute_gan_D_real(d_out_real)
        if optimizer_D is not None:
            errD_real.backward()
        d_out_real_mean = d_out_real.mean()

        # generate the fake data
        if args.model_forward_with_target:
            # if the model's forward function also requires the target data
            if isinstance(data_tar, torch.Tensor):
                data_tar_tm = data_tar.to(device, dtype=nii_dconf.d_dtype)
                if args.model_forward_with_file_name:
                    data_gen = pt_model_G(data_in, data_tar_tm, data_info)
                else:
                    data_gen = pt_model_G(data_in, data_tar_tm)
            else:
                nii_display.f_print("--model-forward-with-target is set")
                nii_display.f_die("but data_tar is not loaded")
        else:
            if args.model_forward_with_file_name:
                # if the model's forward function requires the file name
                data_gen = pt_model_G(data_in, data_info)
            else:
                # default case
                data_gen = pt_model_G(data_in)

        # train the critic with fake data (detached from the generator)
        d_out_fake = pt_model_D(data_gen.detach())
        errD_fake = loss_wrapper.compute_gan_D_fake(d_out_fake)
        if optimizer_D is not None:
            errD_fake.backward()
        d_out_fake_mean = d_out_fake.mean()

        errD = errD_real + errD_fake
        if optimizer_D is not None:
            optimizer_D.step()
            # WGAN weight clipping: keep the critic weights within
            # [-wgan_clamp, wgan_clamp] (skipped during validation)
            for p in pt_model_D.parameters():
                p.data.clamp_(-wgan_clamp, wgan_clamp)

        ############################
        # Update the generator
        ############################
        pt_model_G.zero_grad()
        d_out_fake_for_G = pt_model_D(data_gen)
        errG_gan = loss_wrapper.compute_gan_G(d_out_fake_for_G)
        errG_aux = loss_wrapper.compute_aux(data_gen, data_tar)
        errG = errG_gan + errG_aux

        # update the generator only once every num_critic mini-batches
        if data_idx % num_critic == 0 and optimizer_G is not None:
            errG.backward()
            optimizer_G.step()

        d_out_fake_for_G_mean = d_out_fake_for_G.mean()

        # construct the list of losses for the monitor
        # (only errG_aux counts toward the main loss criterion)
        loss_computed = [[errG_aux, errG_gan, errD_real, errD_fake,
                          d_out_real_mean, d_out_fake_mean,
                          d_out_fake_for_G_mean],
                         [True, False, False, False, False, False, False]]

        # to handle the case of multiple loss functions
        _, loss_vals, loss_flags = nii_nn_tools.f_process_loss(loss_computed)

        # save the training process information to the monitor
        end_time = time.time()
        batchsize = len(data_info)
        for idx, data_seq_info in enumerate(data_info):
            # loss_value is the average over the samples in the batch
            monitor.log_loss(loss_vals, loss_flags,
                             (end_time - start_time) / batchsize,
                             data_seq_info, idx_orig.numpy()[idx],
                             epoch_idx)
            # print the information
            if args.verbose == 1:
                monitor.print_error_for_batch(data_idx * batchsize + idx,
                                              idx_orig.numpy()[idx],
                                              epoch_idx)

        # start the timer for the next mini-batch
        start_time = time.time()

    # loop done
    return


def f_train_wrapper_GAN(
        args, pt_model_G, pt_model_D, loss_wrapper, device,
        optimizer_G_wrapper, optimizer_D_wrapper,
        train_dataset_wrapper,
        val_dataset_wrapper = None,
        checkpoint_G = None, checkpoint_D = None):
| """ |
| f_train_wrapper_GAN( |
| args, pt_model_G, pt_model_D, loss_wrapper, device, |
| optimizer_G_wrapper, optimizer_D_wrapper, |
| train_dataset_wrapper, val_dataset_wrapper = None, |
| check_point = None): |
| |
| A wrapper to run the training process |
| |
| Args: |
| args: argument information given by argpase |
| pt_model_G: generator, pytorch model (torch.nn.Module) |
| pt_model_D: discriminator, pytorch model (torch.nn.Module) |
| loss_wrapper: a wrapper over loss functions |
| loss_wrapper.compute_D_real(discriminator_output) |
| loss_wrapper.compute_D_fake(discriminator_output) |
| loss_wrapper.compute_G(discriminator_output) |
| loss_wrapper.compute_G(fake, real) |
| |
| device: torch.device("cuda") or torch.device("cpu") |
| |
| optimizer_G_wrapper: |
| a optimizer wrapper for generator (defined in op_manager.py) |
| optimizer_D_wrapper: |
| a optimizer wrapper for discriminator (defined in op_manager.py) |
| |
| train_dataset_wrapper: |
| a wrapper over training data set (data_io/default_data_io.py) |
| train_dataset_wrapper.get_loader() returns torch.DataSetLoader |
| |
| val_dataset_wrapper: |
| a wrapper over validation data set (data_io/default_data_io.py) |
| it can None. |
| |
| checkpoint_G: |
| a check_point that stores every thing to resume training |
| |
| checkpoint_D: |
| a check_point that stores every thing to resume training |
| """ |
    nii_display.f_print_w_date("Start model training")

    ##############
    ## Preparation
    ##############
    # get the optimizers
    optimizer_G_wrapper.print_info()
    optimizer_D_wrapper.print_info()
    optimizer_G = optimizer_G_wrapper.optimizer
    optimizer_D = optimizer_D_wrapper.optimizer
    epoch_num = optimizer_G_wrapper.get_epoch_num()
    no_best_epoch_num = optimizer_G_wrapper.get_no_best_epoch_num()

    # get the data loader for the training set
    train_dataset_wrapper.print_info()
    train_data_loader = train_dataset_wrapper.get_loader()
    train_seq_num = train_dataset_wrapper.get_seq_num()

    # monitor for the training set
    monitor_trn = nii_monitor.Monitor(epoch_num, train_seq_num)

    # the same for the validation set, if it is provided
    if val_dataset_wrapper is not None:
        val_dataset_wrapper.print_info()
        val_data_loader = val_dataset_wrapper.get_loader()
        val_seq_num = val_dataset_wrapper.get_seq_num()
        monitor_val = nii_monitor.Monitor(epoch_num, val_seq_num)
    else:
        monitor_val = None

    # training log information
    train_log = ''
    model_tags = ["_G", "_D"]

    # prepare for DataParallel
    if torch.cuda.device_count() > 1 and args.multi_gpu_data_parallel:
        nii_display.f_die("data_parallel not implemented for GAN")
    else:
        nii_display.f_print("\nUse single GPU: %s\n" % \
                            (torch.cuda.get_device_name(device)))
        flag_multi_device = False
        normtarget_f = None

    # send the models to the device
    pt_model_G.to(device, dtype=nii_dconf.d_dtype)
    pt_model_D.to(device, dtype=nii_dconf.d_dtype)

    # print the model information
    nii_display.f_print("Setup generator")
    nii_nn_tools.f_model_show(pt_model_G, model_type='GAN')
    nii_display.f_print("Setup discriminator")
    nii_nn_tools.f_model_show(pt_model_D, do_model_def_check=False,
                              model_type='GAN')
    nii_nn_tools.f_loss_show(loss_wrapper, model_type='GAN')

    ###############################
    ## Resume training if necessary
    ###############################
    # load the checkpoints for the generator and the discriminator
    cp_names = nii_nn_manage_conf.CheckPointKey()
    if checkpoint_G is not None or checkpoint_D is not None:
        for checkpoint, optimizer, pt_model, model_name in \
            zip([checkpoint_G, checkpoint_D], [optimizer_G, optimizer_D],
                [pt_model_G, pt_model_D], ["Generator", "Discriminator"]):
            nii_display.f_print("For %s" % (model_name))
            if type(checkpoint) is dict:
                # a checkpoint dictionary saved by this toolkit
                # load the model parameters
                if cp_names.state_dict in checkpoint:
                    pt_model.load_state_dict(
                        nii_nn_tools.f_state_dict_wrapper(
                            checkpoint[cp_names.state_dict],
                            flag_multi_device))
                # load the optimizer state
                if cp_names.optimizer in checkpoint:
                    optimizer.load_state_dict(checkpoint[cp_names.optimizer])
                # load the training history
                if not args.ignore_training_history_in_trained_model:
                    if cp_names.trnlog in checkpoint:
                        monitor_trn.load_state_dic(
                            checkpoint[cp_names.trnlog])
                    if cp_names.vallog in checkpoint and monitor_val:
                        monitor_val.load_state_dic(
                            checkpoint[cp_names.vallog])
                    if cp_names.info in checkpoint:
                        train_log = checkpoint[cp_names.info]
                    nii_display.f_print("Load checkpoint, resume training")
                else:
                    nii_display.f_print("Load pretrained model and optimizer")
            elif checkpoint is not None:
                # only a state_dict was saved
                pt_model.load_state_dict(
                    nii_nn_tools.f_state_dict_wrapper(
                        checkpoint, flag_multi_device))
                nii_display.f_print("Load pretrained model")
            else:
                nii_display.f_print("No pretrained model")

    ######################
    ### Start training
    ######################
    # other variables
    flag_early_stopped = False
    start_epoch = monitor_trn.get_epoch()
    epoch_num = monitor_trn.get_max_epoch()

    # select the per-epoch training function: use the WGAN recipe when
    # the loss wrapper sets flag_wgan, otherwise the standard GAN recipe
    if hasattr(loss_wrapper, "flag_wgan") and loss_wrapper.flag_wgan:
        f_wrapper_gan_one_epoch = f_run_one_epoch_WGAN
    else:
        f_wrapper_gan_one_epoch = f_run_one_epoch_GAN

    # print the head of the training log
    _ = nii_op_display_tk.print_log_head()
    nii_display.f_print_message(train_log, flush=True, end='')

    # loop over epochs
    for epoch_idx in range(start_epoch, epoch_num):

        # training on one epoch
        pt_model_D.train()
        pt_model_G.train()

        f_wrapper_gan_one_epoch(
            args, pt_model_G, pt_model_D,
            loss_wrapper, device,
            monitor_trn, train_data_loader,
            epoch_idx, optimizer_G, optimizer_D,
            normtarget_f)

        time_trn = monitor_trn.get_time(epoch_idx)
        loss_trn = monitor_trn.get_loss(epoch_idx)

        # validation on the development set, if provided
        if val_dataset_wrapper is not None:
            # set eval() if necessary
            if args.eval_mode_for_validation:
                pt_model_G.eval()
                pt_model_D.eval()
            with torch.no_grad():
                f_wrapper_gan_one_epoch(
                    args, pt_model_G, pt_model_D,
                    loss_wrapper,
                    device,
                    monitor_val, val_data_loader,
                    epoch_idx, None, None, normtarget_f)
            time_val = monitor_val.get_time(epoch_idx)
            loss_val = monitor_val.get_loss(epoch_idx)
        else:
            time_val, loss_val = 0, 0

        # is this epoch a new best on the validation set?
        if val_dataset_wrapper is not None:
            flag_new_best = monitor_val.is_new_best()
        else:
            flag_new_best = True

        # print the training log for this epoch
        train_log += nii_op_display_tk.print_train_info(
            epoch_idx, time_trn, loss_trn, time_val, loss_val,
            flag_new_best, optimizer_G_wrapper.get_lr_info())

        # save the best model
        if flag_new_best:
            for pt_model, tmp_tag in zip([pt_model_G, pt_model_D],
                                         model_tags):
                tmp_best_name = nii_nn_tools.f_save_trained_name(args, tmp_tag)
                torch.save(pt_model.state_dict(), tmp_best_name)

        # save the intermediate checkpoint for this epoch
        if not args.not_save_each_epoch:
            for pt_model, optimizer, model_tag in \
                zip([pt_model_G, pt_model_D], [optimizer_G, optimizer_D],
                    model_tags):

                tmp_model_name = nii_nn_tools.f_save_epoch_name(
                    args, epoch_idx, model_tag)
                if monitor_val is not None:
                    tmp_val_log = monitor_val.get_state_dic()
                else:
                    tmp_val_log = None
                # save everything needed to resume training
                tmp_dic = {
                    cp_names.state_dict : pt_model.state_dict(),
                    cp_names.info : train_log,
                    cp_names.optimizer : optimizer.state_dict(),
                    cp_names.trnlog : monitor_trn.get_state_dic(),
                    cp_names.vallog : tmp_val_log
                }
                torch.save(tmp_dic, tmp_model_name)
                if args.verbose == 1:
                    nii_display.f_eprint(str(datetime.datetime.now()))
                    nii_display.f_eprint("Save {:s}".format(tmp_model_name),
                                         flush=True)

        # early stopping
        if monitor_val is not None and \
           monitor_val.should_early_stop(no_best_epoch_num):
            flag_early_stopped = True
            break

    # print the tail of the training log
    nii_op_display_tk.print_log_tail()
    if flag_early_stopped:
        nii_display.f_print("Training finished by early stopping")
    else:
        nii_display.f_print("Training finished")
    nii_display.f_print("Model is saved to", end = '')
    for model_tag in model_tags:
        nii_display.f_print("{}".format(
            nii_nn_tools.f_save_trained_name(args, model_tag)))
    return


if __name__ == "__main__":
    print("nn_manager for GAN")