# karthikeya1212's picture
# Upload 115 files
# cda88e0 verified
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms, utils

from utils.utils_file import get_cur_time_stamp, create_folder
def compute_differentiable_params(net):
    """Count the trainable parameters of *net* (those with requires_grad=True)."""
    total = 0
    for param in net.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def convert_Relight_latent_light(latent_feature):
    """Reshape a light latent from (n, 6, 16, 16) to (n, 3, 16, 32).

    The per-sample element count (6*16*16 == 3*16*32 == 1536) is unchanged;
    only the channel/width layout is rearranged. torch.reshape raises if the
    input's per-sample size is not 1536.

    Args:
        latent_feature: tensor with leading batch dimension and 1536
            elements per sample (canonically shaped (n, 6, 16, 16)).

    Returns:
        A tensor of shape (n, 3, 16, 32) viewing the same data.
    """
    batch_size = latent_feature.size(0)
    return torch.reshape(latent_feature, (batch_size, 3, 16, 32))
def show_batch(sample_batch, out_file=None):
    """Render a batch of image tensors as one grid figure and display it.

    Args:
        sample_batch: batch of image tensors accepted by
            torchvision.utils.make_grid.
        out_file: optional path; when given, the figure is also saved there
            before being shown.
    """
    grid = utils.make_grid(sample_batch)
    plt.figure(figsize=(30, 20))
    # make_grid returns C x H x W; imshow wants H x W x C.
    plt.imshow(grid.detach().cpu().numpy().transpose((1, 2, 0)))
    if out_file is not None:
        print('try save ', out_file)
        plt.savefig(out_file)
    plt.show()
def show_light_batch(light_batch):
    """Display a batch of light latents by first reshaping them to image layout."""
    show_batch(convert_Relight_latent_light(light_batch))
def save_loss(figure_fname, train_loss, valid_loss):
    """Plot the train/validation loss histories and save the figure to *figure_fname*."""
    for curve in (train_loss, valid_loss):
        plt.plot(curve)
    plt.legend(['train_loss', 'valid_loss'])
    plt.savefig(figure_fname)
def save_model(output_folder, model, optimizer, epoch, best_loss, fname, hist_train_loss, hist_valid_loss, hist_lr, params):
    """Write a full training checkpoint (model, optimizer, histories) to disk.

    The checkpoint is saved as ``output_folder/fname`` and the full path is
    returned. In multi-GPU mode the wrapped ``.module`` is unwrapped first so
    the stored state_dict comes from the underlying model.
    """
    create_folder(output_folder)
    output_fname = os.path.join(output_folder, fname)
    model_to_save = model
    # Unwrap the parallel wrapper (if any) before grabbing the state dict.
    if params.multi_gpu and hasattr(model_to_save, 'module'):
        model_to_save = model.module
    checkpoint = {
        'epoch': epoch,
        'best_loss': best_loss,
        'model_state_dict': model_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'hist_train_loss': hist_train_loss,
        'hist_valid_loss': hist_valid_loss,
        'hist_lr': hist_lr,
        'params': str(params)
    }
    torch.save(checkpoint, output_fname)
    return output_fname
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group.

    Returns None when the optimizer has no param groups (matching the
    implicit behavior of iterating an empty list).
    """
    groups = optimizer.param_groups
    if groups:
        return groups[0]['lr']
    return None
def set_lr(optimizer, lr):
    """Overwrite the learning rate of every param group in *optimizer* with *lr*."""
    for group in optimizer.param_groups:
        group['lr'] = lr