File size: 4,575 Bytes
944cdc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# pip install git+https://github.com/openai/CLIP.git
# pip install lpips
# pip install dists-pytorch
# pip install scikit-image

import argparse
import os

import clip
import lpips
import numpy as np
import torch
from DISTS_pytorch import DISTS
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchvision import transforms
from tqdm import tqdm

IMG_SIZE = (256, 256)  # resize image to this size for evaluation.
DIR_GT = "ground-truths"  # subdirectory of results_dir holding ground-truth frames.
DIR_PRED = "predictions"  # subdirectory of results_dir holding predicted frames.

# CLI: --results_dir is the root folder containing the two subdirectories above;
# --sample_n randomly subsamples frames (debug only, default -1 = use all).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--results_dir", type=str, required=True, help="The directory of the results"
)
parser.add_argument(
    "--sample_n",
    type=int,
    default=-1,
    help="Randomly sample the number of frames to evaluate. *Use for DEBUG purpose only*",
)


args = parser.parse_args()

res_dir = args.results_dir

dir_gt = os.path.join(args.results_dir, DIR_GT)
dir_pred = os.path.join(args.results_dir, DIR_PRED)

# Ground-truth directory defines the set of frames to evaluate; predictions
# are looked up by the same file name.
img_names = os.listdir(dir_gt)
print(f"number of images: {len(img_names)}")

if args.sample_n > 0:
    # Debug mode: evaluate a random subset instead of the full set.
    img_names = np.random.choice(img_names, args.sample_n, replace=False)
    print(f"sample {args.sample_n} imgs for evaluation")

# Per-frame metric accumulators; averaged at the end of the script.
ssims = []
psnrs = []
distss = []
lpipss = []
clip_scores = []

device = "cuda" if torch.cuda.is_available() else "cpu"

# Perceptual metric networks (lower is better for both).
dists_fn = DISTS().to(device)
lpips_fn = lpips.LPIPS(net="alex").to(device)

# Maps uint8 HWC images to CHW tensors in [-1, 1], as DISTS/LPIPS expect.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
)

np.random.shuffle(img_names)

# CLIP model + its own preprocessing pipeline, used by cal_clip_score below.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)


def cal_clip_score(img1: np.ndarray, img2: np.ndarray) -> torch.Tensor:
    """Compute the CLIP similarity score between two images.

    Both images are preprocessed and encoded with the module-level CLIP model;
    the score is the cosine similarity of the L2-normalized image embeddings
    scaled by the model's learned ``logit_scale`` (the same scaling CLIP
    applies to its logits).

    Args:
        img1 (np.ndarray): The first image. Shape: [H,W,C]. dtype: uint8.
        img2 (np.ndarray): The second image. Shape: [H,W,C]. dtype: uint8.

    Returns:
        torch.Tensor: Scalar tensor holding
        ``logit_scale * cosine_similarity(embed(img1), embed(img2))``.
    """
    img1 = clip_preprocess(Image.fromarray(img1)).unsqueeze(0).to(device)
    img2 = clip_preprocess(Image.fromarray(img2)).unsqueeze(0).to(device)
    img1_features = clip_model.encode_image(img1)
    img2_features = clip_model.encode_image(img2)
    # Casting the norm to float32 promotes the division to float32 even when
    # the encoder returns float16 features (as it does on GPU).
    img1_features = img1_features / img1_features.norm(dim=1, keepdim=True).to(torch.float32)
    img2_features = img2_features / img2_features.norm(dim=1, keepdim=True).to(torch.float32)
    logit_scale = clip_model.logit_scale.exp()
    # Batch size is 1, so summing the elementwise product is the dot product
    # of the two unit vectors, i.e. their cosine similarity.
    score = logit_scale * (img1_features * img2_features).sum()
    return score


def load_img(img_path: str):
    """Load an image from disk as an RGB uint8 numpy array.

    The image is converted to RGB and resized to the global ``IMG_SIZE``
    before conversion to an array.

    Args:
        img_path (str): path to image.

    Returns: np.ndarray | None. dtype: uint8.
        return None if file not exist or occurred some errors during loading.

    """
    try:
        if not os.path.isfile(img_path):
            print(f"file not existed for image: {img_path}")
            return None

        pil_img = Image.open(img_path)
        rgb_resized = pil_img.convert("RGB").resize(IMG_SIZE)
        return np.array(rgb_resized)
    except Exception as e:
        print(f"Exception while loading image: {img_path}: {e}")
        return None


for img_name in tqdm(img_names):
    # Only PNG frames are evaluated; other files in the GT dir are ignored.
    if not img_name.endswith(".png"):
        continue

    img_gt = load_img(os.path.join(dir_gt, img_name))
    if img_gt is None:  # skip if errors with GT image.
        continue

    img_pred = load_img(os.path.join(dir_pred, img_name))

    if img_pred is None:
        # Missing prediction: assign worst-case scores. SSIM/PSNR/CLIP are
        # higher-is-better -> 0; DISTS/LPIPS are lower-is-better -> 1.
        ssim_value, psnr_value, clip_score, dists_value, lpips_value = 0, 0, 0, 1, 1
    else:
        # SSIM and PSNR operate directly on the uint8 HWC arrays.
        ssim_value = ssim(img_gt, img_pred, channel_axis=2)
        psnr_value = psnr(img_gt, img_pred)

        with torch.no_grad():
            # CLIP similarity score.
            clip_score = cal_clip_score(img_gt, img_pred).item()

            # DISTS and LPIPS expect batched CHW tensors normalized to [-1, 1].
            img_gt_norm = transform(img_gt).unsqueeze(0).to(device)
            img_pred_norm = transform(img_pred).unsqueeze(0).to(device)
            dists_value = dists_fn(img_gt_norm, img_pred_norm).item()
            lpips_value = lpips_fn(img_gt_norm, img_pred_norm).item()

    ssims.append(ssim_value)
    psnrs.append(psnr_value)
    distss.append(dists_value)
    lpipss.append(lpips_value)
    clip_scores.append(clip_score)

# Guard against ZeroDivisionError when no valid image pairs were evaluated
# (e.g. empty GT dir, no .png files, or every GT image failed to load).
if not ssims:
    print("no valid image pairs were evaluated; nothing to report")
else:
    print("ssim={}".format(sum(ssims) / len(ssims)))
    print("psnr={}".format(sum(psnrs) / len(psnrs)))
    print("dists={}".format(sum(distss) / len(distss)))
    print("lpips={}".format(sum(lpipss) / len(lpipss)))
    print("clip_score={}".format(sum(clip_scores) / len(clip_scores)))