# File size: 4,575 Bytes
# 944cdc2
# pip install git+https://github.com/openai/CLIP.git
# pip install lpips
# pip install dists-pytorch
# pip install scikit-image
import argparse
import os
import clip
import lpips
import numpy as np
import torch
from DISTS_pytorch import DISTS
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchvision import transforms
from tqdm import tqdm
IMG_SIZE = (256, 256)  # resize image to this size for evaluation.
DIR_GT = "ground-truths"
DIR_PRED = "predictions"
# Command-line interface. --results_dir must contain a DIR_GT and a DIR_PRED
# sub-directory; predictions are looked up by the ground-truth file name.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--results_dir", type=str, required=True, help="The directory of the results"
)
parser.add_argument(
    "--sample_n",
    type=int,
    default=-1,
    help="Randomly sample the number of frames to evaluate. *Use for DEBUG purpose only*",
)
args = parser.parse_args()
res_dir = args.results_dir
dir_gt = os.path.join(args.results_dir, DIR_GT)
dir_pred = os.path.join(args.results_dir, DIR_PRED)
# Only .png files are evaluated. Filter before counting and sampling so that
# stray files in the directory neither inflate the reported count nor use up
# sample slots.
img_names = [name for name in os.listdir(dir_gt) if name.endswith(".png")]
print(f"number of images: {len(img_names)}")
if args.sample_n > 0:
    # np.random.choice(..., replace=False) raises ValueError when asked for more
    # items than exist, so clamp the request to what is available.
    sample_n = min(args.sample_n, len(img_names))
    img_names = np.random.choice(img_names, sample_n, replace=False)
    print(f"sample {sample_n} imgs for evaluation")
# Per-image metric values; averaged after the evaluation loop.
ssims = []
psnrs = []
distss = []
lpipss = []
clip_scores = []
device = "cuda" if torch.cuda.is_available() else "cpu"
dists_fn = DISTS().to(device)
lpips_fn = lpips.LPIPS(net="alex").to(device)
# ToTensor maps uint8 HxWxC to float CxHxW in [0, 1]; Normalize with
# mean=std=0.5 then maps that to [-1, 1], the input range LPIPS expects.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
)
# Only changes the processing order; the reported metrics are plain means.
np.random.shuffle(img_names)
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
def cal_clip_score(img1: np.ndarray, img2: np.ndarray):
    """Calculate the CLIP image-to-image similarity score.

    Both images are preprocessed with CLIP's own preprocessing and embedded by
    the CLIP image encoder. The score is the cosine similarity of the two
    L2-normalized embeddings multiplied by ``exp(clip_model.logit_scale)``.

    Args:
        img1 (np.ndarray): The first image. Shape: [H,W,C]. dtype: uint8.
        img2 (np.ndarray): The second image. Shape: [H,W,C]. dtype: uint8.

    Returns:
        torch.Tensor: single-element tensor holding the scaled cosine
        similarity (call ``.item()`` to get a Python float).
    """
    # Inference only: no autograd graph is needed, even if the caller forgets
    # to disable gradients.
    with torch.no_grad():
        img1_t = clip_preprocess(Image.fromarray(img1)).unsqueeze(0).to(device)
        img2_t = clip_preprocess(Image.fromarray(img2)).unsqueeze(0).to(device)
        # Cast the embeddings themselves to float32 *before* taking the norm,
        # so the normalization runs in full precision even when the encoder
        # outputs a reduced-precision dtype (e.g. fp16 on CUDA).
        feats1 = clip_model.encode_image(img1_t).to(torch.float32)
        feats2 = clip_model.encode_image(img2_t).to(torch.float32)
        feats1 = feats1 / feats1.norm(dim=1, keepdim=True)
        feats2 = feats2 / feats2.norm(dim=1, keepdim=True)
        logit_scale = clip_model.logit_scale.exp()
        return logit_scale * (feats1 * feats2).sum()
def load_img(img_path: str):
    """Load an image file as an RGB uint8 numpy array resized to IMG_SIZE.

    Args:
        img_path (str): path to image.

    Returns:
        np.ndarray | None: array of shape [H,W,3], dtype uint8. Returns None
        (after printing a message) if the file does not exist or any error
        occurs while loading; callers treat None as "skip / penalize".
    """
    try:
        if not os.path.isfile(img_path):
            print(f"file not existed for image: {img_path}")
            return None
        # The context manager closes the file handle that Image.open keeps
        # open. convert() produces an independent copy of the decoded pixels,
        # so it is safe to use after the file is closed.
        with Image.open(img_path) as img_pil:
            img_rgb = img_pil.convert("RGB")
        return np.array(img_rgb.resize(IMG_SIZE))
    except Exception as e:
        # Deliberate best-effort: report and let the caller decide how to
        # handle an unreadable image.
        print(f"Exception while loading image: {img_path}: {e}")
        return None
for img_name in tqdm(img_names):
    # Only .png files are evaluated; anything else in the GT directory is ignored.
    if not img_name.endswith(".png"):
        continue
    img_gt = load_img(os.path.join(dir_gt, img_name))
    if img_gt is None:  # skip if errors with GT image.
        continue
    img_pred = load_img(os.path.join(dir_pred, img_name))
    if img_pred is None:  # set values if missing prediction.
        # Fixed penalty values for a missing/unreadable prediction: 0 for the
        # similarity metrics (SSIM, PSNR, CLIP) and 1 for the distance metrics
        # (DISTS, LPIPS).
        ssim_value, psnr_value, clip_score, dists_value, lpips_value = 0, 0, 0, 1, 1
    else:
        # SSIM and PSNR on the uint8 HxWxC arrays (axis 2 holds the channels).
        ssim_value = ssim(img_gt, img_pred, channel_axis=2)
        # NOTE(review): psnr() returns inf for pixel-identical images, which
        # makes the reported mean inf; consider capping if that can happen.
        psnr_value = psnr(img_gt, img_pred)
        with torch.no_grad():
            # clip score
            clip_score = cal_clip_score(img_gt, img_pred).item()
            # `transform` maps pixels to [-1, 1], the input range LPIPS expects.
            img_gt_norm = transform(img_gt).unsqueeze(0).to(device)
            img_pred_norm = transform(img_pred).unsqueeze(0).to(device)
            # DISTS expects inputs in [0, 1] (per DISTS_pytorch's docs), so
            # undo the [-1, 1] scaling before passing the images to it.
            img_gt_01 = img_gt_norm * 0.5 + 0.5
            img_pred_01 = img_pred_norm * 0.5 + 0.5
            dists_value = dists_fn(img_gt_01, img_pred_01).item()
            lpips_value = lpips_fn(img_gt_norm, img_pred_norm).item()
    ssims.append(ssim_value)
    psnrs.append(psnr_value)
    distss.append(dists_value)
    lpipss.append(lpips_value)
    clip_scores.append(clip_score)
if not ssims:
    # No ground-truth image could be scored; fail with a clear message instead
    # of a ZeroDivisionError while averaging.
    raise SystemExit(f"no images were evaluated under {dir_gt}")
# Report the mean of every metric, one "name=value" line each.
for metric_name, metric_values in (
    ("ssim", ssims),
    ("psnr", psnrs),
    ("dists", distss),
    ("lpips", lpipss),
    ("clip_score", clip_scores),
):
    print("{}={}".format(metric_name, sum(metric_values) / len(metric_values)))
|