| import torch |
| import numpy as np |
| import torch.nn.functional as F |
| import cv2 |
def padding_4x(seq_noise, multiple=16):
    """Reflect-pad the last two dimensions (H, W) up to a multiple of ``multiple``.

    Despite the name, the default target multiple is 16. Padding is appended
    only at the bottom and right edges, so the original content stays at
    ``[..., :H, :W]`` and ``depadding`` can strip it again.

    Args:
        seq_noise: tensor with at least 2 dimensions; the last two are treated
            as (H, W). Any number of leading dimensions is accepted, e.g.
            (B, C, H, W) or (B, F, C, H, W).
        multiple: value H and W are rounded up to (default 16).

    Returns:
        tuple: ``(padded, expanded_h, expanded_w)`` -- the padded tensor and the
        number of rows / columns appended.

    Raises:
        RuntimeError: from ``F.pad`` when a required pad is not smaller than
            the corresponding input size (a restriction of ``'reflect'`` mode).
    """
    *lead, height, width = seq_noise.shape
    # -n % m == (m - n % m) % m: 0 when already aligned, else the shortfall.
    expanded_h = -height % multiple
    expanded_w = -width % multiple

    # Reflect padding of only the last two dims is implemented for 3-D/4-D
    # tensors, so fold all leading dims into one; reflection is per (H, W)
    # plane, so the result is identical to padding the tensor directly.
    planes = seq_noise.reshape(-1, height, width)
    padexp = (0, expanded_w, 0, expanded_h)
    planes = F.pad(input=planes, pad=padexp, mode='reflect')
    seq_noise = planes.reshape(*lead, height + expanded_h, width + expanded_w)
    return seq_noise, expanded_h, expanded_w
|
|
def depadding(seq_denoise, expanded_h, expanded_w):
    """Remove the rows/columns appended by ``padding_4x``.

    Crops ``expanded_h`` rows from the bottom and ``expanded_w`` columns from
    the right of the last two dimensions. Leading dimensions can be anything,
    e.g. (B, C, H, W), (B, F, C, H, W) or (C, H, W).

    Args:
        seq_denoise: tensor or array whose last two dimensions are (H, W).
        expanded_h: number of rows to drop from the bottom (0 = none).
        expanded_w: number of columns to drop from the right (0 = none).

    Returns:
        The cropped view; the input object itself when nothing is cropped.
    """
    # Ellipsis indexing always targets the last two axes; a fixed
    # ``[:, :, :-h, :]`` would crop the wrong axis for 5-D input.
    if expanded_h:
        seq_denoise = seq_denoise[..., :-expanded_h, :]
    if expanded_w:
        seq_denoise = seq_denoise[..., :-expanded_w]
    return seq_denoise
| def chunkV3(net, input_data, option, patch_h = 516, patch_w = 516, patch_h_overlap = 16, patch_w_overlap = 16): |
| |
|
|
| |
| |
|
|
| shape_list = input_data.shape |
|
|
| if option == 'image': |
| B, C, H, W = shape_list[0], shape_list[1], shape_list[2], shape_list[3] |
| if option == 'RViDeformer': |
| B, F, C, H, W = shape_list[0], shape_list[1], shape_list[2], shape_list[3], shape_list[4] |
| if option == 'three2one': |
| B, FC , H, W = shape_list[0], shape_list[1], shape_list[2], shape_list[3] |
| |
| if option == 'image': |
| test_result = torch.zeros_like(input_data).cpu() |
| if option == 'RViDeformer': |
| test_result = torch.zeros_like(input_data).cpu() |
| if option == 'three2one': |
| test_result = torch.zeros((B, 4 , H, W)).cpu() |
|
|
|
|
| |
| h_index = 1 |
| while (patch_h*h_index-patch_h_overlap*(h_index-1)) < H: |
| if option == 'image': |
| test_horizontal_result = torch.zeros((B,C,patch_h,W)).cpu() |
| if option == 'RViDeformer': |
| test_horizontal_result = torch.zeros((B, F, C, patch_h, W)).cpu() |
| if option == 'three2one': |
| test_horizontal_result = torch.zeros((B, 4, patch_h, W)).cpu() |
|
|
| h_begin = patch_h*(h_index-1)-patch_h_overlap*(h_index-1) |
| h_end = patch_h*h_index-patch_h_overlap*(h_index-1) |
| w_index = 1 |
| while (patch_w*w_index-patch_w_overlap*(w_index-1)) < W: |
| w_begin = patch_w*(w_index-1)-patch_w_overlap*(w_index-1) |
| w_end = patch_w*w_index-patch_w_overlap*(w_index-1) |
| test_patch = input_data[...,h_begin:h_end,w_begin:w_end] |
|
|
| with torch.no_grad(): |
| test_patch_result = net(test_patch).detach().cpu() |
|
|
| if w_index == 1: |
| test_horizontal_result[...,w_begin:w_end] = test_patch_result |
| else: |
| for i in range(patch_w_overlap): |
| test_horizontal_result[...,w_begin+i] = test_horizontal_result[...,w_begin+i]*(patch_w_overlap-1-i)/(patch_w_overlap-1)+test_patch_result[...,i]*i/(patch_w_overlap-1) |
| test_horizontal_result[...,w_begin+patch_w_overlap:w_end] = test_patch_result[...,patch_w_overlap:] |
| w_index += 1 |
| |
| test_patch = input_data[...,h_begin:h_end,-patch_w:] |
|
|
| with torch.no_grad(): |
| test_patch_result = net(test_patch).detach().cpu() |
| last_range = w_end-(W-patch_w) |
|
|
| for i in range(last_range): |
| test_horizontal_result[...,W-patch_w+i] = test_horizontal_result[...,W-patch_w+i]*(last_range-1-i)/(last_range-1)+test_patch_result[...,i]*i/(last_range-1) |
| test_horizontal_result[...,w_end:] = test_patch_result[...,last_range:] |
|
|
| if h_index == 1: |
| test_result[...,h_begin:h_end,:] = test_horizontal_result |
| else: |
| for i in range(patch_h_overlap): |
| test_result[...,h_begin+i,:] = test_result[...,h_begin+i,:]*(patch_h_overlap-1-i)/(patch_h_overlap-1)+test_horizontal_result[...,i,:]*i/(patch_h_overlap-1) |
| test_result[...,h_begin+patch_h_overlap:h_end,:] = test_horizontal_result[...,patch_h_overlap:,:] |
| h_index += 1 |
|
|
| if option == 'image': |
| test_horizontal_result = torch.zeros((B,C,patch_h,W)).cpu() |
| if option == 'RViDeformer': |
| test_horizontal_result = torch.zeros((B, F, C, patch_h, W)).cpu() |
| if option == 'three2one': |
| test_horizontal_result = torch.zeros((B, 4, patch_h, W)).cpu() |
| |
| w_index = 1 |
| while (patch_w*w_index-patch_w_overlap*(w_index-1)) < W: |
| w_begin = patch_w*(w_index-1)-patch_w_overlap*(w_index-1) |
| w_end = patch_w*w_index-patch_w_overlap*(w_index-1) |
| test_patch = input_data[...,-patch_h:,w_begin:w_end] |
| |
| with torch.no_grad(): |
| test_patch_result = net(test_patch).detach().cpu() |
|
|
| if w_index == 1: |
| test_horizontal_result[...,w_begin:w_end] = test_patch_result |
| else: |
| for i in range(patch_w_overlap): |
| test_horizontal_result[...,w_begin+i] = test_horizontal_result[...,w_begin+i]*(patch_w_overlap-1-i)/(patch_w_overlap-1)+test_patch_result[...,i]*i/(patch_w_overlap-1) |
| test_horizontal_result[...,w_begin+patch_w_overlap:w_end] = test_patch_result[...,patch_w_overlap:] |
| w_index += 1 |
|
|
| test_patch = input_data[...,-patch_h:,-patch_w:] |
|
|
| with torch.no_grad(): |
| test_patch_result = net(test_patch).detach().cpu() |
| last_range = w_end-(W-patch_w) |
| for i in range(last_range): |
| test_horizontal_result[...,W-patch_w+i] = test_horizontal_result[...,W-patch_w+i]*(last_range-1-i)/(last_range-1)+test_patch_result[...,i]*i/(last_range-1) |
| test_horizontal_result[...,w_end:] = test_patch_result[...,last_range:] |
|
|
| last_last_range = h_end-(H-patch_h) |
| for i in range(last_last_range): |
| test_result[...,H-patch_w+i,:] = test_result[...,H-patch_w+i,:]*(last_last_range-1-i)/(last_last_range-1)+test_horizontal_result[...,i,:]*i/(last_last_range-1) |
| test_result[...,h_end:,:] = test_horizontal_result[...,last_last_range:,:] |
| |
| |
| |
|
|
| return test_result |
|
|
|
|
def calculate_psnr(img, img2, input_order='HWC', max_val=1.):
    """Compute the PSNR (dB) between two images.

    PSNR = 10 * log10(max_val ** 2 / MSE), where the MSE is the mean squared
    difference over every element of the arrays.

    Args:
        img (ndarray): first image, any shape, e.g. (H, W, C), (C, H, W) or (H, W).
        img2 (ndarray): second image, same shape as ``img``.
        input_order (str): 'HWC' or 'CHW'. Only validated: the MSE averages
            over all elements, so the axis order does not affect the result.
        max_val (float): peak signal value. The default 1.0 suits images scaled
            to [0, 1]; use 255. for 8-bit data.

    Returns:
        float: PSNR in dB, or ``inf`` when the images are identical.

    Raises:
        AssertionError: if the shapes differ.
        ValueError: if ``input_order`` is not 'HWC' or 'CHW'.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')

    # No axis reordering: a mean over all elements is layout-independent, and
    # skipping it keeps 2-D (grayscale) arrays working.
    diff = img.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 10. * np.log10(max_val * max_val / mse)
|
|
|
|
def calculate_ssim(img, img2, input_order='HWC'):
    """Compute the SSIM between two images, averaged over channels.

    Args:
        img (ndarray): first image. A 3-D array is always transposed with
            ``transpose(1, 2, 0)``, i.e. treated as (C, H, W), whatever
            ``input_order`` says. A 2-D array is treated as one (H, W) channel.
        img2 (ndarray): second image, same shape as ``img``.
        input_order (str): validated to be 'HWC' or 'CHW' but otherwise unused.
            NOTE(review): honouring it would change results for any caller
            that passes (C, H, W) arrays with the default -- confirm callers
            before wiring it up.

    Returns:
        float: mean of the per-channel ``_ssim`` values. ``_ssim``'s stability
        constants are built with a data range of 1, i.e. values in [0, 1].

    Raises:
        AssertionError: if the shapes differ.
        ValueError: if ``input_order`` is not 'HWC' or 'CHW'.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')

    if img.ndim == 2:
        # Single-channel image: add a trailing channel axis.
        img = img[..., None]
        img2 = img2[..., None]
    else:
        # (C, H, W) -> (H, W, C)
        img = img.transpose(1, 2, 0)
        img2 = img2.transpose(1, 2, 0)

    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)

    ssims = [_ssim(img[..., i], img2[..., i]) for i in range(img.shape[2])]
    return np.array(ssims).mean()
| |
def _ssim(img, img2, data_range=1.):
    """Calculate SSIM (structural similarity) for one channel.

    Called by ``calculate_ssim`` once per channel (it passes float64 2-D slices).

    Args:
        img (ndarray): 2-D array (H, W).
        img2 (ndarray): 2-D array with the same shape as ``img``.
        data_range (float): dynamic range of the pixel values. The stability
            constants are ``(0.01 * data_range) ** 2`` and
            ``(0.03 * data_range) ** 2``. The default 1.0 corresponds to
            images scaled to [0, 1]; pass 255. for 8-bit data.

    Returns:
        float: mean of the SSIM map over the valid region. A 5-pixel border
        (the radius of the 11x11 Gaussian window) is discarded, so inputs
        need H, W > 10 for a non-empty map.
    """
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    # 11x11 Gaussian window (sigma 1.5) as the outer product of a 1-D kernel.
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    # Local statistics; [5:-5, 5:-5] keeps only positions where the window
    # lies fully inside the image.
    mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
    return ssim_map.mean()