# Hugging Face Hub page header (not part of the script):
# uploader: YuqianFu; commit message: "Upload folder using huggingface_hub";
# commit 944cdc2 (verified).
# pip install git+https://github.com/openai/CLIP.git
# pip install lpips
# pip install dists-pytorch
# pip install scikit-image
import argparse
import os
import clip
import lpips
import numpy as np
import torch
from DISTS_pytorch import DISTS
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchvision import transforms
from tqdm import tqdm
# Size passed to PIL's ``Image.resize`` for every image before evaluation.
IMG_SIZE = (256, 256)
# Sub-directories of --results_dir holding the reference and predicted images.
DIR_GT = "ground-truths"
DIR_PRED = "predictions"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--results_dir", type=str, required=True, help="The directory of the results"
)
parser.add_argument(
    "--sample_n",
    type=int,
    default=-1,
    help="Randomly sample the number of frames to evaluate. *Use for DEBUG purpose only*",
)
args = parser.parse_args()

res_dir = args.results_dir
dir_gt = os.path.join(args.results_dir, DIR_GT)
dir_pred = os.path.join(args.results_dir, DIR_PRED)

# The ground-truth directory defines which file names get evaluated.
img_names = os.listdir(dir_gt)
print(f"number of images: {len(img_names)}")

if args.sample_n > 0:
    # Sampling without replacement raises ValueError when more items are
    # requested than exist, so clamp the request to the available count.
    sample_n = min(args.sample_n, len(img_names))
    img_names = np.random.choice(img_names, sample_n, replace=False)
    print(f"sample {sample_n} imgs for evaluation")
# One accumulator per metric; each receives one value per evaluated image.
ssims, psnrs, distss, lpipss, clip_scores = [], [], [], [], []

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Learned perceptual metrics, moved to the selected device.
dists_fn = DISTS().to(device)
lpips_fn = lpips.LPIPS(net="alex").to(device)

# Convert an image array to a tensor, then normalise every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Randomise the evaluation order in place.
np.random.shuffle(img_names)

# CLIP model plus its matching PIL-image preprocessing pipeline.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
def cal_clip_score(img1: np.ndarray, img2: np.ndarray):
    """Calculate the CLIP similarity score between two images.

    Both images are preprocessed, encoded with the CLIP image encoder, and
    L2-normalised; the score is the dot product of the two embeddings
    multiplied by the exponentiated ``logit_scale`` of the model.

    No ``torch.no_grad()`` is applied here; the caller is expected to disable
    gradient tracking if it is not needed.

    Args:
        img1 (np.ndarray): The first image. Shape: [H,W,C]. dtype: uint8.
        img2 (np.ndarray): The second image. Shape: [H,W,C]. dtype: uint8.

    Returns:
        torch.Tensor: single-element tensor holding the scaled similarity;
        call ``.item()`` to obtain a Python float.
    """
    img1 = clip_preprocess(Image.fromarray(img1)).unsqueeze(0).to(device)
    img2 = clip_preprocess(Image.fromarray(img2)).unsqueeze(0).to(device)

    # Cast the embeddings to float32 *before* normalising. The previous code
    # applied ``.to(torch.float32)`` to the norm only (method call binds
    # tighter than ``/``), so the norm itself was computed in whatever
    # precision the encoder returned (possibly half precision on GPU).
    img1_features = clip_model.encode_image(img1).to(torch.float32)
    img2_features = clip_model.encode_image(img2).to(torch.float32)

    img1_features = img1_features / img1_features.norm(dim=1, keepdim=True)
    img2_features = img2_features / img2_features.norm(dim=1, keepdim=True)

    logit_scale = clip_model.logit_scale.exp()
    # Element-wise product summed over all entries == dot product for the
    # single-image batches built above.
    score = logit_scale * (img1_features * img2_features).sum()
    return score
def load_img(img_path: str) -> "np.ndarray | None":
    """Load an image file as an RGB array resized to ``IMG_SIZE``.

    Loading is best-effort: any problem is reported on stdout and signalled
    to the caller with ``None`` instead of an exception.

    Args:
        img_path (str): path to image.

    Returns:
        np.ndarray | None: the resized RGB image (dtype uint8), or ``None``
        if the file does not exist or an error occurred while loading it.
    """
    try:
        if not os.path.isfile(img_path):
            print(f"file not existed for image: {img_path}")
            return None
        # The context manager closes the underlying file handle as soon as
        # the pixel data has been converted, instead of leaving it to the
        # garbage collector.
        with Image.open(img_path) as img_file:
            img_pil = img_file.convert("RGB").resize(IMG_SIZE)
        return np.array(img_pil)
    except Exception as e:  # deliberate catch-all: a bad file must not abort the run
        print(f"Exception while loading image: {img_path}: {e}")
        return None
# DISTS expects inputs in the [0, 1] range, so it gets a plain ToTensor()
# conversion rather than the mean/std-0.5 normalised tensors used for LPIPS.
# NOTE: DISTS values therefore differ from runs that fed it normalised tensors.
to_unit_tensor = transforms.ToTensor()

for img_name in tqdm(img_names):
    if not img_name.endswith(".png"):
        continue

    img_gt = load_img(os.path.join(dir_gt, img_name))
    if img_gt is None:  # skip if errors with GT image.
        continue

    img_pred = load_img(os.path.join(dir_pred, img_name))
    if img_pred is None:
        # Missing/unreadable prediction: record fixed penalty values so the
        # image still counts towards every average.
        ssim_value, psnr_value, clip_score, dists_value, lpips_value = 0, 0, 0, 1, 1
    else:
        # SSIM and PSNR operate directly on the uint8 arrays.
        ssim_value = ssim(img_gt, img_pred, channel_axis=2)
        psnr_value = psnr(img_gt, img_pred)

        with torch.no_grad():
            # clip score
            clip_score = cal_clip_score(img_gt, img_pred).item()

            # LPIPS on the normalised tensors produced by `transform`.
            img_gt_norm = transform(img_gt).unsqueeze(0).to(device)
            img_pred_norm = transform(img_pred).unsqueeze(0).to(device)
            lpips_value = lpips_fn(img_gt_norm, img_pred_norm).item()

            # DISTS on [0, 1] tensors (see note above the loop).
            img_gt_unit = to_unit_tensor(img_gt).unsqueeze(0).to(device)
            img_pred_unit = to_unit_tensor(img_pred).unsqueeze(0).to(device)
            dists_value = dists_fn(img_gt_unit, img_pred_unit).item()

    ssims.append(ssim_value)
    psnrs.append(psnr_value)
    distss.append(dists_value)
    lpipss.append(lpips_value)
    clip_scores.append(clip_score)

# All five lists grow in lockstep, so checking one is enough. Without this
# guard an empty run would fail with ZeroDivisionError below.
if not ssims:
    raise SystemExit("no .png images were evaluated; cannot compute mean metrics")

# Report the mean of every metric over the evaluated images.
for metric_name, values in (
    ("ssim", ssims),
    ("psnr", psnrs),
    ("dists", distss),
    ("lpips", lpipss),
    ("clip_score", clip_scores),
):
    print("{}={}".format(metric_name, sum(values) / len(values)))