File size: 2,445 Bytes
cda88e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#import matplotlib.pyplot as plt
import os
from torchvision import transforms, utils
import torch
#import matplotlib.pyplot as plt
import numpy as np
from utils.utils_file import get_cur_time_stamp, create_folder

def compute_differentiable_params(net):
    """Count the scalar parameters of ``net`` that require gradients.

    Args:
        net: Object exposing ``parameters()``; each yielded parameter must
            provide ``requires_grad`` and ``numel()``.

    Returns:
        Total number of elements over all parameters whose
        ``requires_grad`` flag is truthy (0 when there are none).
    """
    trainable_count = 0
    for param in net.parameters():
        # Parameters with requires_grad unset do not contribute.
        if param.requires_grad:
            trainable_count += param.numel()
    return trainable_count

def convert_Relight_latent_light(latent_feature):
    """Reshape a 4-D latent light tensor to ``(batch, 3, 16, 32)``.

    Intended use is n x 6 x 16 x 16 -> n x 3 x 16 x 32. The target shape is
    hard-coded, so each sample must hold exactly 3 * 16 * 32 elements or the
    reshape raises.

    Args:
        latent_feature: Tensor with exactly four dimensions; the first is
            treated as the batch dimension.

    Returns:
        The same data reshaped to ``(batch, 3, 16, 32)``.
    """
    # The four-way unpack doubles as a rank check: non 4-D input fails here.
    n_samples, _channels, _height, _width = latent_feature.size()
    target_shape = (n_samples, 3, 16, 32)
    return latent_feature.reshape(target_shape)

def show_batch(sample_batch, out_file=None):
    """Display a batch as a single image grid, optionally saving it first.

    Args:
        sample_batch: Batch handed to ``utils.make_grid``
            (assumed to be an image tensor batch -- TODO confirm with callers).
        out_file: Optional output path. When not ``None`` the figure is
            written with ``plt.savefig`` before being shown.
    """
    # Imported here because the module-level matplotlib import is commented
    # out; without this the name ``plt`` is undefined and every call raises
    # NameError. A local import also keeps the module importable on machines
    # without matplotlib.
    import matplotlib.pyplot as plt

    grid = utils.make_grid(sample_batch)
    plt.figure(figsize=(30, 20))
    # Move the leading axis last so imshow receives a channel-last array.
    plt.imshow(grid.detach().cpu().numpy().transpose((1, 2, 0)))

    if out_file is not None:
        print('try save ', out_file)
        plt.savefig(out_file)

    plt.show()

def show_light_batch(light_batch):
    """Reshape a latent light batch and display it as an image grid.

    Args:
        light_batch: Tensor accepted by ``convert_Relight_latent_light``.
    """
    show_batch(convert_Relight_latent_light(light_batch))
    
def save_loss(figure_fname, train_loss, valid_loss):
    """Plot training and validation loss curves and save them to a file.

    Args:
        figure_fname: Path the figure is written to.
        train_loss: Sequence of training-loss values to plot.
        valid_loss: Sequence of validation-loss values to plot.
    """
    # Imported here because the module-level matplotlib import is commented
    # out; without this the name ``plt`` is undefined and the call raises
    # NameError.
    import matplotlib.pyplot as plt

    # Use a dedicated figure so that repeated calls (e.g. once per epoch) do
    # not keep stacking curves onto the same implicit axes.
    fig, ax = plt.subplots()
    ax.plot(train_loss)
    ax.plot(valid_loss)
    ax.legend(['train_loss', 'valid_loss'])
    fig.savefig(figure_fname)
    # Release the figure; otherwise figures accumulate over a training run.
    plt.close(fig)

def save_model(output_folder, model, optimizer, epoch, best_loss, fname, hist_train_loss, hist_valid_loss, hist_lr, params):
    """Write a training checkpoint to ``output_folder/fname``.

    Args:
        output_folder: Destination directory; passed to ``create_folder``
            before anything is written.
        model: Network whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Stored under ``'epoch'``.
        best_loss: Stored under ``'best_loss'``.
        fname: File name of the checkpoint inside ``output_folder``.
        hist_train_loss: Stored under ``'hist_train_loss'``.
        hist_valid_loss: Stored under ``'hist_valid_loss'``.
        hist_lr: Stored under ``'hist_lr'``.
        params: Run configuration; ``params.multi_gpu`` is read and
            ``str(params)`` is stored under ``'params'``.

    Returns:
        The full path of the written checkpoint file.
    """
    create_folder(output_folder)
    checkpoint_path = os.path.join(output_folder, fname)

    # In multi-GPU runs the network may sit behind a wrapper exposing the
    # real model as ``.module``; store the inner model's weights in that case.
    is_wrapped = params.multi_gpu and hasattr(model, 'module')
    net_to_save = model.module if is_wrapped else model

    checkpoint = {
        'epoch': epoch,
        'best_loss': best_loss,
        'model_state_dict': net_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'hist_train_loss': hist_train_loss,
        'hist_valid_loss': hist_valid_loss,
        'hist_lr': hist_lr,
        'params': str(params),
    }
    torch.save(checkpoint, checkpoint_path)
    return checkpoint_path
    
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer: Object exposing an iterable ``param_groups`` whose items
            are indexable by ``'lr'``.

    Returns:
        The ``'lr'`` entry of the first group, or ``None`` when
        ``param_groups`` is empty.
    """
    # Only the first group is consulted; later groups are never touched.
    return next((group['lr'] for group in optimizer.param_groups), None)

def set_lr(optimizer, lr):
    """Assign the same learning rate to every parameter group.

    Args:
        optimizer: Object exposing an iterable ``param_groups`` of dicts.
        lr: Value written to each group's ``'lr'`` entry.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)