import numpy as np
import os, sys
from tqdm import tqdm
from options.options import parse
import argparse

parser = argparse.ArgumentParser(description="Script for testing")
parser.add_argument('-p', '--config', type=str, default='./options/test/LOLBlur.yml',
                    help='Path to the test config file')
args = parser.parse_args()
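# Example invocation (script filename assumed, not taken from the repo):
#   python test.py --config ./options/test/LOLBlur.yml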
|

path_options = args.config
opt = parse(path_options)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

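# torch is imported only after CUDA_VISIBLE_DEVICES is set above, so the
# process only ever sees GPU 0.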
|
import torch
import torch.optim
import torch.multiprocessing as mp
import torch.distributed as dist

from data.dataset_reader.datapipeline import *
from archs import *
from losses import *
from data import *
from utils.utils import create_path_models
from utils.test_utils import *
from ptflops import get_model_complexity_info

|
PATH_MODEL = create_path_models(opt['save'])

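# Loads checkpoint weights into the DDP-wrapped model and prints its
# complexity (MACs / parameter count) as measured by ptflops.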
|
def load_model(model, path_weights):

    # Load the checkpoint on CPU first; weights_only=False because the
    # (trusted) checkpoint may contain pickled objects beyond raw tensors.
    map_location = 'cpu'
    checkpoints = torch.load(path_weights, map_location=map_location, weights_only=False)

    weights = checkpoints['params']
    # Prefix every key with 'module.' so the names match the DDP-wrapped model.
    weights = {'module.' + key: value for key, value in weights.items()}

    macs, params = get_model_complexity_info(model, (3, 256, 256), print_per_layer_stat=False, verbose=False)
    print('Network complexity:', macs, params)

    model.load_state_dict(weights)
    print('Loaded weights correctly')

    return model

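# Entry point for each spawned process: rank identifies this process within
# the world_size-process group.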
|
def run_evaluation(rank, world_size):

    setup(rank, world_size=world_size)

    test_loader, _ = create_test_data(rank, world_size=world_size, opt=opt['datasets'])
    model, _, _ = create_model(opt['network'], rank=rank)
    model = load_model(model, opt['save']['path'])
    metrics_eval = {}

    # Wait until every process has finished building and loading the model.
    dist.barrier()

    model.eval()
    metrics_eval, _ = eval_model(model, test_loader, metrics_eval, rank=rank, world_size=world_size, eta=True)

    dist.barrier()

    # Only rank 0 reports results. eval_model may return either a flat dict of
    # metrics or one nested dict of metrics per test set.
    if rank == 0:
        if isinstance(next(iter(metrics_eval.values())), dict):
            for key, metric_eval in metrics_eval.items():
                print(f" \t {key} --- PSNR: {metric_eval['valid_psnr']}, SSIM: {metric_eval['valid_ssim']}, LPIPS: {metric_eval['valid_lpips']}")
        else:
            print(f" \t {opt['datasets']['name']} --- PSNR: {metrics_eval['valid_psnr']}, SSIM: {metrics_eval['valid_ssim']}, LPIPS: {metrics_eval['valid_lpips']}")

    cleanup()

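# world_size controls how many processes mp.spawn launches (one per GPU);
# this script evaluates on a single GPU.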
|
def main():
    world_size = 1
    mp.spawn(run_evaluation, args=(world_size,), nprocs=world_size, join=True)


if __name__ == '__main__':
    main()