File size: 7,318 Bytes
2ecc7ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | import torch
import torch.distributed as dist
import sys, os
from lpips import LPIPS
import numpy as np
sys.path.append('../losses')
sys.path.append('../data/datasets/datapipeline')
from losses import *
from tqdm import tqdm
# Shared SSIM metric instance; ``SSIM`` is presumably provided by the
# ``from losses import *`` star-import above. data_range=1.0 matches the peak
# value of 1 used by the PSNR computations in this module.
calc_SSIM = SSIM(data_range=1.0)

# ---------- Set of functions to work with DDP
def setup(rank, world_size, Master_port='12355'):
    """Join the NCCL default process group for a single-node DDP run.

    The rendezvous address is exported through the MASTER_ADDR / MASTER_PORT
    environment variables (address is always localhost) before the group is
    initialised.

    Args:
        rank: index of this process within the group.
        world_size: total number of processes in the group.
        Master_port: rendezvous TCP port; must be a string because it is
            stored in ``os.environ``.
    """
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT=Master_port)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
    """Counterpart to ``setup``: destroy the default process group."""
    dist.destroy_process_group()
def reduce_tensor(tensor, world_size):
    """Sum ``tensor`` over all ranks of the default group, then divide by ``world_size``.

    Works on a copy, so the caller's tensor is left untouched.

    Args:
        tensor: tensor to reduce. Use a floating dtype, because in-place true
            division of an integer tensor raises.
        world_size: divisor applied after the sum. It must equal the number of
            participating ranks for the result to be a mean.

    Returns:
        A new tensor holding ``sum_over_ranks(tensor) / world_size``.
    """
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced.div_(world_size)
    return reduced
def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` from exactly one process.

    When a ``torch.distributed`` default process group is initialised, only
    rank 0 saves, so ranks do not race on the same file. Without a process
    group (ordinary single-process training/evaluation), the model is always
    saved. ``dist.get_rank()`` would raise in that situation, so it is only
    consulted when a group exists.

    Args:
        model: module whose ``state_dict()`` is saved as-is (if it is
            DDP-wrapped, the keys keep that wrapper's prefix).
        path: destination passed to ``torch.save``.
    """
    in_process_group = dist.is_available() and dist.is_initialized()
    if not in_process_group or dist.get_rank() == 0:
        torch.save(model.state_dict(), path)
def shuffle_sampler(samplers, epoch):
    '''
    Propagate ``epoch`` to every sampler by calling ``set_epoch(epoch)`` on it
    (e.g. DistributedSampler instances, which use the epoch for shuffling).

    A falsy ``samplers`` (``None`` or an empty sequence) is a no-op.
    '''
    for each_sampler in samplers or ():
        each_sampler.set_epoch(epoch)
def eval_one_loader(model, test_loader, metrics, rank=0, world_size=1, eta=False):
    """Evaluate ``model`` on one loader and record PSNR, SSIM and LPIPS.

    Each batch contributes:
      * one PSNR, from its mean squared error with a peak value of 1;
      * one SSIM;
      * one mean LPIPS value.
    The per-batch values are averaged. When a ``torch.distributed`` process
    group is initialised, they are also averaged across ranks with
    ``reduce_tensor``.

    Args:
        model: called as ``model(low)``; its output is compared to ``high``.
        test_loader: iterable of ``(high, low)`` pairs (ground truth first,
            model input second). ``len()`` is only required when ``eta`` is
            True.
        metrics: dict updated in place with ``'valid_psnr'``, ``'valid_ssim'``
            and ``'valid_lpips'`` (Python floats).
        rank: device argument forwarded to ``.to()`` for data, LPIPS network
            and the reduction tensor.
        world_size: divisor used by ``reduce_tensor`` for cross-rank averaging.
        eta: show a tqdm progress bar.

    Returns:
        ``(metrics, imgs_dict)``. ``imgs_dict`` holds the first sample of the
        last batch under ``'input'``, ``'output'`` and ``'gt'``.

    Raises:
        ValueError: if ``test_loader`` yields no batches. It is raised after
            the cross-rank reduction, so other ranks are not left waiting.
    """
    calc_LPIPS = LPIPS(net='vgg', verbose=False).to(rank)
    keys = ('valid_psnr', 'valid_ssim', 'valid_lpips')
    per_batch = {key: [] for key in keys}
    # len() is only taken for the progress bar, so loaders without __len__
    # still work when eta is False.
    progress = tqdm(total=int(len(test_loader)) if eta else None, disable=not eta)
    # The context manager closes the bar even if evaluation raises.
    with torch.no_grad(), progress:
        for high_batch_valid, low_batch_valid in test_loader:
            high_batch_valid = high_batch_valid.to(rank)
            low_batch_valid = low_batch_valid.to(rank)
            enhanced_batch_valid = model(low_batch_valid)
            mse = torch.mean((high_batch_valid - enhanced_batch_valid) ** 2)
            # PSNR with a peak value of 1, i.e. images are expected in [0, 1].
            psnr = 20 * torch.log10(1. / torch.sqrt(mse))
            ssim = calc_SSIM(enhanced_batch_valid, high_batch_valid)
            # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1].
            lpips_map = calc_LPIPS(enhanced_batch_valid, high_batch_valid, normalize=True)
            per_batch['valid_psnr'].append(psnr.item())
            per_batch['valid_ssim'].append(ssim.item())
            per_batch['valid_lpips'].append(torch.mean(lpips_map).item())
            progress.update(1)

    # One float64 tensor -> a single collective instead of three. An empty
    # loader contributes NaN so every rank still joins the reduction.
    local_means = torch.tensor(
        [float(np.mean(per_batch[key])) if per_batch[key] else float('nan') for key in keys],
        dtype=torch.float64,
    ).to(rank)
    if dist.is_available() and dist.is_initialized():
        local_means = reduce_tensor(local_means, world_size=world_size)
    if not per_batch['valid_psnr']:
        raise ValueError('eval_one_loader: test_loader yielded no batches')
    metrics.update(zip(keys, local_means.tolist()))
    imgs_dict = {'input': low_batch_valid[0], 'output': enhanced_batch_valid[0], 'gt': high_batch_valid[0]}
    return metrics, imgs_dict
def eval_model(model, test_loader, metrics, rank=None, world_size=1, eta=False):
    '''
    Evaluate ``model`` on a single loader or on a dict of named loaders.

    Accepted forms of ``test_loader``:
      * a loader (anything that is not a dict) -> evaluated directly; returns
        ``(metrics, imgs_dict)`` with ``metrics`` updated in place;
      * ``{'data': loader}`` -> treated exactly like a single loader;
      * ``{name: {'loader': loader, ...}, ...}`` (any number of entries,
        including one) -> each entry is evaluated separately. Returns
        ``({name: metrics}, {name: imgs_dict})``, with ``name`` converted to
        ``str``. In this form the ``metrics`` argument is not used.

    ``rank``, ``world_size`` and ``eta`` are forwarded to ``eval_one_loader``.
    '''
    if not isinstance(test_loader, dict):
        test_loader = {'data': test_loader}
    if len(test_loader) == 1 and 'data' in test_loader:
        return eval_one_loader(model, test_loader['data'], metrics, rank=rank, world_size=world_size, eta=eta)
    # Named loaders: a single named entry must take this path too; it
    # previously fell into the 'data' branch and raised KeyError.
    all_metrics, all_imgs_dict = {}, {}
    for key, entry in test_loader.items():
        name = f'{key}'
        all_metrics[name], all_imgs_dict[name] = eval_one_loader(
            model, entry['loader'], {}, rank=rank, world_size=world_size, eta=eta)
    return all_metrics, all_imgs_dict
def eval_one_loader_two_models(model1, model2, test_loader, metrics, devices=('cuda:0', 'cuda:1'), eta=False):
    """Evaluate the two-stage pipeline ``model2(clamp(model1(low), 0, 1))`` on one loader.

    Device flow:
      * ``model1`` runs on ``devices[0]``;
      * its output is clamped to [0, 1] and moved to ``devices[1]`` for
        ``model2``;
      * the final output is moved back to ``devices[0]``, where PSNR (peak
        value 1), SSIM and LPIPS are computed per batch.
    Metrics are averaged over batches in this process only; there is no
    cross-rank reduction.

    Args:
        model1, model2: the two stages; ``model1`` receives the ``low`` batch.
        test_loader: iterable of ``(high, low)`` pairs (ground truth first,
            input second). ``len()`` is only required when ``eta`` is True.
        metrics: dict updated in place with ``'valid_psnr'``, ``'valid_ssim'``
            and ``'valid_lpips'`` (Python floats).
        devices: two-element sequence ``(first_stage_device,
            second_stage_device)``.
        eta: show a tqdm progress bar.

    Returns:
        ``(metrics, imgs_dict)``. ``imgs_dict`` holds the first sample of the
        last batch under ``'input'``, ``'output'`` and ``'gt'``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    calc_LPIPS = LPIPS(net='vgg', verbose=False).to(devices[0])
    keys = ('valid_psnr', 'valid_ssim', 'valid_lpips')
    per_batch = {key: [] for key in keys}
    # len() is only taken for the progress bar, so loaders without __len__
    # still work when eta is False.
    progress = tqdm(total=int(len(test_loader)) if eta else None, disable=not eta)
    with torch.no_grad(), progress:
        for high_batch_valid, low_batch_valid in test_loader:
            high_batch_valid = high_batch_valid.to(devices[0])
            low_batch_valid = low_batch_valid.to(devices[0])
            first_stage = torch.clamp(model1(low_batch_valid), 0., 1.)
            enhanced_batch_valid = model2(first_stage.to(devices[1])).to(devices[0])
            mse = torch.mean((high_batch_valid - enhanced_batch_valid) ** 2)
            psnr = 20 * torch.log10(1. / torch.sqrt(mse))
            ssim = calc_SSIM(enhanced_batch_valid, high_batch_valid)
            # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1].
            lpips_map = calc_LPIPS(enhanced_batch_valid, high_batch_valid, normalize=True)
            per_batch['valid_psnr'].append(psnr.item())
            per_batch['valid_ssim'].append(ssim.item())
            per_batch['valid_lpips'].append(torch.mean(lpips_map).item())
            progress.update(1)
    if not per_batch['valid_psnr']:
        raise ValueError('eval_one_loader_two_models: test_loader yielded no batches')
    for key in keys:
        metrics[key] = float(np.mean(per_batch[key]))
    imgs_dict = {'input': low_batch_valid[0], 'output': enhanced_batch_valid[0], 'gt': high_batch_valid[0]}
    return metrics, imgs_dict
def eval_model_two_models(model1, model2, test_loader, metrics, devices=('cuda:0', 'cuda:1'), eta=False):
    '''
    Evaluate the ``model1`` -> ``model2`` pipeline on one loader or on a dict
    of named loaders.

    Accepted forms of ``test_loader``:
      * a loader (anything that is not a dict) -> evaluated directly; returns
        ``(metrics, imgs_dict)`` with ``metrics`` updated in place;
      * ``{'data': loader}`` -> treated exactly like a single loader;
      * ``{name: {'loader': loader, ...}, ...}`` (any number of entries,
        including one) -> each entry is evaluated separately. Returns
        ``({name: metrics}, {name: imgs_dict})``, with ``name`` converted to
        ``str``. In this form the ``metrics`` argument is not used.

    ``devices`` and ``eta`` are forwarded to ``eval_one_loader_two_models``.
    '''
    if not isinstance(test_loader, dict):
        test_loader = {'data': test_loader}
    if len(test_loader) == 1 and 'data' in test_loader:
        return eval_one_loader_two_models(model1, model2, test_loader['data'], metrics, devices=devices, eta=eta)
    # Named loaders: a single named entry must take this path too; it
    # previously fell into the 'data' branch and raised KeyError.
    all_metrics, all_imgs_dict = {}, {}
    for key, entry in test_loader.items():
        name = f'{key}'
        all_metrics[name], all_imgs_dict[name] = eval_one_loader_two_models(
            model1, model2, entry['loader'], {}, devices=devices, eta=eta)
    return all_metrics, all_imgs_dict