| | |
| | |
| | |
| | |
| |
|
import argparse
import os

import clip
import lpips
import numpy as np
import torch
from DISTS_pytorch import DISTS
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from torchvision import transforms
from tqdm import tqdm
| |
|
# Size passed to PIL's resize() for every loaded image, so ground truth and
# prediction always have matching dimensions when the metrics are computed.
IMG_SIZE = (256, 256)
# Names of the sub-directories looked up inside --results_dir.
DIR_GT = "ground-truths"
DIR_PRED = "predictions"
| |
|
# Command-line interface: the results directory is mandatory; --sample_n
# optionally restricts evaluation to a random subset (debugging aid).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--results_dir", type=str, required=True, help="The directory of the results"
)
parser.add_argument(
    "--sample_n",
    type=int,
    default=-1,
    help="Randomly sample the number of frames to evaluate. *Use for DEBUG purpose only*",
)
| |
|
| |
|
args = parser.parse_args()

res_dir = args.results_dir

# Ground-truth and predicted images live in sibling sub-directories of the
# results directory and are paired by file name.
dir_gt = os.path.join(args.results_dir, DIR_GT)
dir_pred = os.path.join(args.results_dir, DIR_PRED)

img_names = os.listdir(dir_gt)
print(f"number of images: {len(img_names)}")

# Only .png files are ever scored, so drop everything else *before* sampling;
# otherwise --sample_n could silently yield fewer evaluated images than asked.
img_names = [name for name in img_names if name.endswith(".png")]

if args.sample_n > 0:
    # Clamp the request: np.random.choice(..., replace=False) raises ValueError
    # when asked for more items than are available.
    sample_n = min(args.sample_n, len(img_names))
    if sample_n < len(img_names):
        img_names = np.random.choice(img_names, sample_n, replace=False).tolist()
    print(f"sample {sample_n} imgs for evaluation")
| |
|
# Per-image metric values; they are averaged after all images are processed.
ssims = []
psnrs = []
distss = []
lpipss = []
clip_scores = []

device = "cuda" if torch.cuda.is_available() else "cpu"

# Learned perceptual metrics. They are used for inference only, so they are
# explicitly put in eval mode.
dists_fn = DISTS().to(device).eval()
lpips_fn = lpips.LPIPS(net="alex").to(device).eval()

# ToTensor converts an HWC uint8 array to a CHW float tensor in [0, 1];
# Normalize with mean=std=0.5 then rescales it to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
)

# Randomise the processing order in place.
np.random.shuffle(img_names)

clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
| |
|
| |
|
def cal_clip_score(img1: np.ndarray, img2: np.ndarray):
    """Calculate the CLIP similarity score between two images.

    Both images are passed through the CLIP image encoder, the embeddings are
    L2-normalised, and their dot product is multiplied by the model's
    exponentiated ``logit_scale``.

    Gradients are not disabled here; wrap the call in ``torch.no_grad()`` when
    only the value is needed.

    Args:
        img1 (np.ndarray): The first image. Shape: [H,W,C]. dtype: uint8.
        img2 (np.ndarray): The second image. Shape: [H,W,C]. dtype: uint8.

    Returns:
        torch.Tensor: 0-dim tensor holding the scaled similarity of the two
        image embeddings.
    """

    def _embed(img: np.ndarray) -> torch.Tensor:
        # Encode one image into a unit-length float32 embedding.
        batch = clip_preprocess(Image.fromarray(img)).unsqueeze(0).to(device)
        # Cast *before* normalising: the encoder may return half-precision
        # features (e.g. on GPU), and the norm should be computed in float32.
        feats = clip_model.encode_image(batch).to(torch.float32)
        return feats / feats.norm(dim=1, keepdim=True)

    img1_features = _embed(img1)
    img2_features = _embed(img2)
    logit_scale = clip_model.logit_scale.exp()
    # With a single image per side, the summed element-wise product is the
    # dot product of the two unit vectors.
    score = logit_scale * (img1_features * img2_features).sum()
    return score
| |
|
| |
|
def load_img(img_path: str):
    """Load an image as an RGB array resized to ``IMG_SIZE``.

    Loading is best effort: problems are reported on stdout and signalled to
    the caller with ``None`` instead of an exception, so one bad file does not
    abort a whole evaluation run.

    Args:
        img_path (str): path to image.

    Returns:
        np.ndarray | None: the image converted to RGB and resized to
        ``IMG_SIZE``, or ``None`` if the file does not exist or an error
        occurred while loading it.
    """
    try:
        if not os.path.isfile(img_path):
            print(f"file not existed for image: {img_path}")
            return None

        # The context manager releases the underlying file handle even when
        # decoding fails part-way through.
        with Image.open(img_path) as img_file:
            img_pil = img_file.convert("RGB").resize(IMG_SIZE)
        return np.array(img_pil)
    except Exception as e:  # deliberate catch-all: report and skip this file
        print(f"Exception while loading image: {img_path}: {e}")
        return None
| |
|
| |
|
# DISTS_pytorch documents its inputs as RGB batches with data range 0~1,
# whereas `transform` rescales to [-1, 1]. DISTS therefore gets its own plain
# ToTensor conversion.
# NOTE(review): DISTS values will differ from runs made before this fix.
to_unit_range = transforms.ToTensor()

for img_name in tqdm(img_names):
    if not img_name.endswith(".png"):
        continue

    img_gt = load_img(os.path.join(dir_gt, img_name))
    if img_gt is None:
        # Without a ground truth there is nothing to compare against.
        continue

    img_pred = load_img(os.path.join(dir_pred, img_name))

    if img_pred is None:
        # A missing or unreadable prediction is not skipped; it is scored
        # with fixed penalty values for every metric.
        ssim_value, psnr_value, clip_score, dists_value, lpips_value = 0, 0, 0, 1, 1
    else:
        # Pixel-level metrics operate directly on the uint8 HWC arrays.
        ssim_value = ssim(img_gt, img_pred, channel_axis=2)
        psnr_value = psnr(img_gt, img_pred)

        with torch.no_grad():
            clip_score = cal_clip_score(img_gt, img_pred).item()

            # LPIPS: inputs normalised by `transform`.
            img_gt_norm = transform(img_gt).unsqueeze(0).to(device)
            img_pred_norm = transform(img_pred).unsqueeze(0).to(device)
            lpips_value = lpips_fn(img_gt_norm, img_pred_norm).item()

            # DISTS: inputs in [0, 1].
            img_gt_unit = to_unit_range(img_gt).unsqueeze(0).to(device)
            img_pred_unit = to_unit_range(img_pred).unsqueeze(0).to(device)
            dists_value = dists_fn(img_gt_unit, img_pred_unit).item()

    ssims.append(ssim_value)
    psnrs.append(psnr_value)
    distss.append(dists_value)
    lpipss.append(lpips_value)
    clip_scores.append(clip_score)

# All five lists grow in lockstep, so checking one is enough. Exit with a
# clear message instead of a ZeroDivisionError when nothing was evaluated.
if not ssims:
    raise SystemExit(f"no images were evaluated from: {dir_gt}")

for metric_name, metric_values in (
    ("ssim", ssims),
    ("psnr", psnrs),
    ("dists", distss),
    ("lpips", lpipss),
    ("clip_score", clip_scores),
):
    print("{}={}".format(metric_name, sum(metric_values) / len(metric_values)))
| |
|