| import torch |
| import numpy as np |
| import pickle |
| import cv2 |
|
|
def is_numpy_file(filename):
    """Return True when ``filename`` ends with the ``.npy`` extension (case-sensitive)."""
    return filename.endswith(".npy")
|
|
def is_image_file(filename):
    """Return True when ``filename`` ends with ``.jpg`` (case-sensitive; ``.jpeg`` is not matched)."""
    return filename.endswith(".jpg")
|
|
def is_png_file(filename):
    """Return True when ``filename`` ends with the ``.png`` extension (case-sensitive)."""
    return filename.endswith(".png")
|
|
def is_pkl_file(filename):
    """Return True when ``filename`` ends with the ``.pkl`` extension (case-sensitive)."""
    return filename.endswith(".pkl")
|
|
def load_pkl(filename_):
    """Unpickle and return the object stored in the file ``filename_``.

    Security: ``pickle`` can execute arbitrary code while loading, so only
    call this on files from a trusted source.
    """
    with open(filename_, mode='rb') as handle:
        return pickle.load(handle)
|
|
def save_dict(dict_, filename_):
    """Pickle ``dict_`` into ``filename_``, overwriting any existing file.

    Despite the name, any picklable object is accepted. Returns ``None``.
    """
    with open(filename_, mode='wb') as handle:
        pickle.dump(dict_, handle)
|
|
def load_npy(filepath):
    """Load and return the array stored at ``filepath`` using ``np.load`` with its default options."""
    return np.load(filepath)
|
|
def load_img(filepath):
    """Read an image file and return it as an RGB float32 array scaled by 1/255.

    Args:
        filepath: path of the image, passed unchanged to ``cv2.imread``.

    Returns:
        ``np.ndarray`` of dtype float32, channels in RGB order, with every
        value divided by 255.

    Raises:
        OSError: if OpenCV cannot read the file (missing, unreadable, or an
            unsupported/corrupt format). ``cv2.imread`` reports this by
            returning ``None`` rather than raising.
    """
    bgr = cv2.imread(filepath)
    if bgr is None:
        # Without this check, cvtColor fails later with an opaque assertion
        # error that does not mention the offending path.
        raise OSError(f"Could not read image file: {filepath!r}")
    # OpenCV decodes into BGR channel order; convert to RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return rgb.astype(np.float32) / 255.
|
|
def save_img(filepath, img):
    """Write an RGB image to ``filepath`` with OpenCV.

    Args:
        filepath: destination path; OpenCV chooses the encoder from its
            extension.
        img: image array in RGB channel order. NOTE(review): most OpenCV
            encoders expect 8-bit data, so float images in [0, 1] presumably
            need converting by the caller before saving. Confirm at call sites.

    Returns:
        bool: the success flag reported by ``cv2.imwrite``. It is ``False``
        when the image could not be written, which OpenCV does not raise for.
    """
    # Convert back to OpenCV's BGR channel order before encoding.
    return cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
|
def myPSNR(tar_img, prd_img, data_range=1.0):
    """Compute the peak signal-to-noise ratio (in dB) between two image tensors.

    PSNR = 20 * log10(data_range / RMSE), where the RMSE is taken over every
    element of the (clamped) inputs.

    Args:
        tar_img: reference (target) image tensor.
        prd_img: predicted image tensor, same shape as ``tar_img`` (or
            broadcastable to it).
        data_range: peak signal value. Both images are clamped to
            [0, data_range] before comparison. The default of 1.0 matches
            images normalised to [0, 1].

    Returns:
        0-dim tensor holding the PSNR in dB. It is ``inf`` when the clamped
        images are identical (RMSE of zero).
    """
    # Clamp both images into the valid range before differencing.
    diff = torch.clamp(prd_img, 0, data_range) - torch.clamp(tar_img, 0, data_range)
    rmse = diff.pow(2).mean().sqrt()
    return 20 * torch.log10(data_range / rmse)
|
|
def batch_PSNR(img1, img2, data_range=None):
    """Return the mean PSNR (dB) over paired images from two batches.

    Args:
        img1, img2: batches iterated in lockstep along their first dimension.
            Because ``zip`` is used, extra items in the longer batch are
            ignored.
        data_range: peak signal value of the images. ``None`` (default)
            treats images as lying in [0, 1], as ``myPSNR`` does by default.
            Otherwise both images are divided by ``data_range`` first. That
            is mathematically equivalent to computing PSNR with peak
            ``data_range`` and clamping to [0, data_range].

    Returns:
        The arithmetic mean of the per-pair ``myPSNR`` values (a 0-dim tensor
        for tensor inputs).

    Raises:
        ValueError: if ``data_range`` is not positive, or if the batches
            contain no image pairs.
    """
    if data_range is not None and data_range <= 0:
        raise ValueError(f"data_range must be positive, got {data_range!r}")
    scores = []
    for im1, im2 in zip(img1, img2):
        if data_range is not None:
            # Rescale to [0, 1] so myPSNR's clamp range and peak of 1
            # correspond to [0, data_range] in the original units.
            im1 = im1 / data_range
            im2 = im2 / data_range
        scores.append(myPSNR(im1, im2))
    if not scores:
        raise ValueError("batch_PSNR requires at least one image pair")
    return sum(scores) / len(scores)
|
|