File size: 1,363 Bytes
4336727 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | import torch
import numpy as np
import pickle
import cv2
def is_numpy_file(filename):
    """Return True when *filename* ends with the ``.npy`` suffix (case-sensitive)."""
    return filename.endswith(".npy")
def is_image_file(filename):
    """Return True when *filename* ends with ``.jpg`` (case-sensitive; ``.jpeg`` is not matched)."""
    return filename.endswith(".jpg")
def is_png_file(filename):
    """Return True when *filename* ends with the ``.png`` suffix (case-sensitive)."""
    return filename.endswith(".png")
def is_pkl_file(filename):
    """Return True when *filename* ends with the ``.pkl`` suffix (case-sensitive)."""
    return filename.endswith(".pkl")
def load_pkl(filename_):
    """Unpickle and return the single object stored in *filename_*.

    NOTE(review): unpickling can execute arbitrary code, so only call this on
    files from a trusted source.
    """
    with open(filename_, "rb") as handle:
        # Equivalent to pickle.load(handle); the file is closed on return.
        return pickle.Unpickler(handle).load()
def save_dict(dict_, filename_):
    """Pickle *dict_* into *filename_*, overwriting any existing file.

    Any picklable object is accepted despite the name; uses pickle's default protocol.
    """
    with open(filename_, "wb") as handle:
        # Equivalent to pickle.dump(dict_, handle).
        pickle.Pickler(handle).dump(dict_)
def load_npy(filepath):
    """Load and return the array stored in the ``.npy`` file at *filepath* via ``np.load``."""
    return np.load(filepath)
def load_img(filepath):
    """Read an image file and return it as an RGB float32 array divided by 255.

    Args:
        filepath: Path of the image file, passed unchanged to ``cv2.imread``.

    Returns:
        ``np.float32`` array in RGB channel order. ``cv2.imread``'s default
        ``IMREAD_COLOR`` flag yields 3-channel 8-bit data, so values end up
        in [0, 1].

    Raises:
        OSError: If ``cv2.imread`` cannot read the file (missing, unreadable,
            or unsupported format).
    """
    bgr = cv2.imread(filepath)
    if bgr is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error surfaces later as an opaque cv2.error
        # from cvtColor that does not mention the offending path.
        raise OSError(f"cv2.imread could not read image: {filepath!r}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    img /= 255.0  # in place: avoids allocating a second full-size float copy
    return img
def save_img(filepath, img):
    """Convert an RGB image to BGR and write it to *filepath* with ``cv2.imwrite``.

    Args:
        filepath: Destination path; its extension selects the encoder.
        img: Image in RGB channel order.
            NOTE(review): presumably HxWx3. For most formats cv2.imwrite
            expects 8-bit data, so a float image from load_img (values in
            [0, 1]) would need to be scaled back to uint8 first. Confirm
            callers do this.

    Returns:
        bool: The success flag from ``cv2.imwrite``. It is False when the file
        could not be written, e.g. when the target directory does not exist.
    """
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite signals failure only through this flag; surface it
    # instead of discarding it so callers can detect silent write failures.
    return cv2.imwrite(filepath, bgr)
def myPSNR(tar_img, prd_img, data_range=1):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    Both inputs are clamped to ``[0, data_range]``. The root-mean-square
    error is taken over all elements, and the result is
    ``20 * log10(data_range / rmse)``.

    Args:
        tar_img: Target (reference) tensor.
        prd_img: Predicted tensor; must be broadcast-compatible with tar_img.
        data_range: Peak signal value. The default of 1 suits images scaled to
            [0, 1]; use e.g. 255 for 8-bit data.

    Returns:
        0-dim tensor holding the PSNR. It is ``inf`` when the clamped inputs
        are identical (RMSE of 0).

    Raises:
        ValueError: If ``data_range`` is not positive.
    """
    if data_range <= 0:
        raise ValueError(f"data_range must be positive, got {data_range!r}")
    diff = torch.clamp(prd_img, 0, data_range) - torch.clamp(tar_img, 0, data_range)
    rmse = (diff ** 2).mean().sqrt()
    return 20 * torch.log10(data_range / rmse)
def batch_PSNR(img1, img2, data_range=None):
    """Return the mean PSNR over element-wise pairs of two batches.

    Args:
        img1: Target batch; iterated along its first dimension.
        img2: Predicted batch, paired with img1 via ``zip``. Extra elements in
            the longer batch are ignored.
        data_range: Peak signal value (e.g. 255 for 8-bit data). ``None``
            keeps myPSNR's assumption that values lie in [0, 1].

    Returns:
        The arithmetic mean of the per-pair myPSNR values.

    Raises:
        ValueError: If ``data_range`` is not positive, or if the batches yield
            no pairs.
    """
    if data_range is not None and data_range <= 0:
        raise ValueError(f"data_range must be positive, got {data_range!r}")
    scores = []
    for target, pred in zip(img1, img2):
        if data_range is not None:
            # myPSNR measures against a peak of 1. Dividing both images by the
            # peak gives clamp(x, 0, R) / R and rmse / R, so the result is
            # exactly 20 * log10(R / rmse) on the original scale.
            target = target / data_range
            pred = pred / data_range
        scores.append(myPSNR(target, pred))
    if not scores:
        raise ValueError("batch_PSNR received empty batches")
    return sum(scores) / len(scores)
|