Spaces:
Sleeping
Sleeping
| import os | |
| import gc | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| from metrics import pt_psnr, calculate_ssim, calculate_psnr | |
| from pytorch_msssim import ssim | |
| from utils import save_rgb | |
def _first_image_hwc(batch_tensor):
    """Return image 0 of a channel-first batched tensor as a float32 HWC numpy array clipped to [0, 1]."""
    img = batch_tensor[0].permute(1, 2, 0).cpu().detach().numpy()
    return np.clip(img, 0, 1).astype(np.float32)


def test_model(model, language_model, lm_head, testsets, device, promptify, savepath="results/"):
    """Evaluate a restoration model on each test set and print PSNR / SSIM.

    For every dataset in ``testsets`` the ``(HQ, LQ, filename)`` batches are run
    through ``model``, conditioned on a text embedding built from
    ``promptify(testset.degradation)``. Per dataset, the mean PSNR and SSIM of the
    restored image and the PSNR of the degraded input (baseline) are printed.
    For the datasets named in ``derain_datasets`` the Y-channel PSNR/SSIM is
    reported as well.

    Args:
        model: restoration network, called as ``model(lq_image, text_embedding)``.
        language_model: callable mapping a list of prompt strings to an embedding
            tensor, or a falsy value to run without one (``model`` then receives
            ``None`` as its text embedding).
        lm_head: called on the ``language_model`` output and returning
            ``(text_embedding, degradation_prediction)``; only used when
            ``language_model`` is set.
        testsets: iterable of datasets exposing ``name``, ``degradation`` and
            ``deg_class`` whose items are ``(hq, lq, filename)``.
        device: torch device the images and embeddings are moved to.
        promptify: callable turning ``testset.degradation`` into a prompt string.
        savepath: directory under which restored images are written, one
            sub-directory per dataset; a falsy value disables saving.
    """
    model.eval()
    if language_model:
        language_model.eval()
        lm_head.eval()

    derain_datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800']

    # Without a language model there is no text conditioning. Bind a value up front
    # instead of hitting a NameError on the first call to `model`.
    # NOTE(review): assumes `model` accepts None as its text embedding -- confirm.
    text_embd = None

    with torch.no_grad():
        for testset in testsets:
            testset_name = testset.name
            if savepath:
                dt_results_path = os.path.join(savepath, testset_name)
                # makedirs also creates a missing `savepath` itself (os.mkdir raised
                # FileNotFoundError there) and tolerates an existing directory.
                os.makedirs(dt_results_path, exist_ok=True)

            print (">>> Eval on", testset_name, testset.degradation, testset.deg_class)
            test_dataloader = DataLoader(testset, batch_size=1, num_workers=4, drop_last=True, shuffle=False)

            psnr_dataset = []
            ssim_dataset = []
            psnr_noisy = []
            # Y-channel metrics are only computed for the datasets listed above.
            use_y_channel = testset_name in derain_datasets
            psnr_y_dataset = []
            ssim_y_dataset = []

            for idx, batch in enumerate(test_dataloader):
                x = batch[0].to(device)  # HQ image
                y = batch[1].to(device)  # LQ image
                f = batch[2][0]  # filename
                t = [promptify(testset.degradation) for _ in range(x.shape[0])]

                if language_model:
                    if idx < 5:
                        # print the input prompt for debugging
                        print("\tInput prompt:", t)
                    lm_embd = language_model(t).to(device)
                    # The degradation prediction is not used during evaluation.
                    text_embd, _ = lm_head(lm_embd)

                x_hat = model(y, text_embd)

                psnr_dataset.append(torch.mean(pt_psnr(x, x_hat)).item())
                ssim_dataset.append(ssim(x, x_hat, data_range=1., size_average=True).item())
                psnr_noisy.append(torch.mean(pt_psnr(x, y)).item())

                # The device->host copy is only needed for Y-channel metrics or saving.
                if not (use_y_channel or savepath):
                    continue
                restored_img = _first_image_hwc(x_hat)

                if use_y_channel:
                    # NOTE(review): astype(np.uint8) truncates rather than rounds; kept
                    # so reported Y-channel numbers stay comparable with earlier runs.
                    _x_hat = (restored_img * 255).astype(np.uint8)
                    _x = (_first_image_hwc(x) * 255).astype(np.uint8)
                    psnr_y_dataset.append(
                        calculate_psnr(_x, _x_hat, crop_border=0, input_order='HWC', test_y_channel=True))
                    ssim_y_dataset.append(
                        calculate_ssim(_x, _x_hat, crop_border=0, input_order='HWC', test_y_channel=True))

                ## SAVE RESULTS
                if savepath:
                    img_name = os.path.basename(f)
                    save_rgb(restored_img, os.path.join(dt_results_path, img_name))

            print(f"{testset_name}_base", np.mean(psnr_noisy), "Total images:", len(psnr_dataset))
            print(f"{testset_name}_psnr", np.mean(psnr_dataset))
            print(f"{testset_name}_ssim", np.mean(ssim_dataset))
            if use_y_channel:
                print(f"{testset_name}_psnr-Y", np.mean(psnr_y_dataset), len(psnr_y_dataset))
                print(f"{testset_name}_ssim-Y", np.mean(ssim_y_dataset))
            print()
            print(25 * "***")

            # Release the loader (and its worker processes) before the next dataset.
            del test_dataloader, psnr_dataset, ssim_dataset, psnr_noisy
            gc.collect()
    # END OF FUNCTION