| |
|
| | import os
|
| | from torchvision import transforms, utils
|
| | import torch
|
| |
|
| | import numpy as np
|
| | from utils.utils_file import get_cur_time_stamp, create_folder
|
| |
|
def compute_differentiable_params(net):
    """Return the number of trainable scalar values in ``net``.

    Iterates ``net.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is truthy. Returns 0 when no
    parameter qualifies.
    """
    trainable_count = 0
    for param in net.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    return trainable_count
|
| |
|
def convert_Relight_latent_light(latent_feature):
    """Reshape a 4-D latent tensor to ``(n, 3, 16, 32)``.

    Intended for ``n x 6 x 16 x 16`` inputs. The target shape is
    hard-coded, so every sample must hold exactly 3 * 16 * 32 = 1536
    elements or ``reshape`` raises.

    Args:
        latent_feature: tensor whose ``size()`` has exactly four entries;
            the first one is used as the batch size.

    Returns:
        The same data reshaped to ``(n, 3, 16, 32)``.
    """
    # The 4-way unpack doubles as a rank check: a non-4-D input fails here.
    n_samples, _, _, _ = latent_feature.size()
    target_shape = (n_samples, 3, 16, 32)
    return latent_feature.reshape(target_shape)
|
| |
|
def show_batch(sample_batch, out_file=None):
    """Tile a batch into one grid image, display it, and optionally save it.

    Uses ``torchvision.utils.make_grid`` to build the grid and the
    module-level ``plt`` for plotting.
    NOTE(review): no import of ``plt`` is visible in this excerpt --
    confirm ``matplotlib.pyplot`` is imported as ``plt`` in the real file.

    Args:
        sample_batch: batch passed straight to ``make_grid``.
        out_file: when not ``None``, the figure is also written to this
            path via ``plt.savefig`` before being shown.
    """
    grid = utils.make_grid(sample_batch)
    plt.figure(figsize=(30, 20))

    # Move axis 0 to the end (presumably channel-first -> channel-last
    # for imshow) after detaching and copying to a CPU numpy array.
    grid_image = grid.detach().cpu().numpy().transpose((1, 2, 0))
    plt.imshow(grid_image)

    if out_file is not None:
        print('try save ', out_file)
        plt.savefig(out_file)

    plt.show()
|
| |
|
def show_light_batch(light_batch):
    """Display a latent-light batch as an image grid.

    The batch is first reshaped by ``convert_Relight_latent_light`` and the
    result is handed to ``show_batch`` (no output file is written).
    """
    show_batch(convert_Relight_latent_light(light_batch))
|
| |
|
def save_loss(figure_fname, train_loss, valid_loss):
    """Plot the training and validation loss curves and save them to a file.

    The curves are drawn on a dedicated figure that is closed afterwards.
    Drawing on the implicit current figure (as before) made repeated calls
    -- e.g. once per epoch -- stack new curves on top of the old ones, so
    the legend ended up labelling stale lines, and the figure was never
    released. As a consequence the curves no longer remain on the caller's
    current figure after this call.

    Args:
        figure_fname: output path handed to ``plt.savefig``.
        train_loss: sequence of training-loss values to plot.
        valid_loss: sequence of validation-loss values to plot.
    """
    # plt.figure() creates a fresh figure and makes it current, so the
    # pyplot calls below cannot touch any figure the caller has open.
    fig = plt.figure()
    try:
        plt.plot(train_loss)
        plt.plot(valid_loss)
        plt.legend(['train_loss', 'valid_loss'])
        plt.savefig(figure_fname)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
|
| |
|
def save_model(output_folder, model, optimizer, epoch, best_loss, fname, hist_train_loss, hist_valid_loss, hist_lr, params):
    """Write a training checkpoint to ``output_folder/fname``.

    Args:
        output_folder: destination directory; passed to ``create_folder``
            before anything is written.
        model: network whose ``state_dict()`` is stored. When
            ``params.multi_gpu`` is truthy and the model has a ``module``
            attribute, ``model.module`` is stored instead.
        optimizer: optimizer whose ``state_dict()`` is stored.
        epoch: stored under ``'epoch'``.
        best_loss: stored under ``'best_loss'``.
        fname: checkpoint file name inside ``output_folder``.
        hist_train_loss: stored under ``'hist_train_loss'``.
        hist_valid_loss: stored under ``'hist_valid_loss'``.
        hist_lr: stored under ``'hist_lr'``.
        params: run configuration; must expose ``multi_gpu``. It is saved
            as ``str(params)``, not as the object itself.

    Returns:
        The full path of the written checkpoint file.
    """
    create_folder(output_folder)
    checkpoint_path = os.path.join(output_folder, fname)

    # Store the wrapped network's weights rather than the wrapper's when
    # running multi-GPU and the model exposes a `module` attribute.
    unwrap = params.multi_gpu and hasattr(model, 'module')
    net_to_save = model.module if unwrap else model

    checkpoint = {
        'epoch': epoch,
        'best_loss': best_loss,
        'model_state_dict': net_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'hist_train_loss': hist_train_loss,
        'hist_valid_loss': hist_valid_loss,
        'hist_lr': hist_lr,
        'params': str(params),
    }
    torch.save(checkpoint, checkpoint_path)
    return checkpoint_path
|
| |
|
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Only the first element of ``optimizer.param_groups`` is consulted.
    Returns ``None`` when there are no parameter groups.
    """
    no_group = object()
    first_group = next(iter(optimizer.param_groups), no_group)
    if first_group is no_group:
        return None
    return first_group['lr']
|
| |
|
def set_lr(optimizer, lr):
    """Assign ``lr`` to the ``'lr'`` entry of every parameter group.

    Mutates each element of ``optimizer.param_groups`` in place and
    returns ``None``.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr
|
| |
|