# -*- coding: utf-8 -*-
import argparse
import importlib
import os

import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader

from core.dataset import TestDataset
from core.metrics import (calc_psnr_and_ssim, calculate_i3d_activations,
                          calculate_vfid, init_i3d_model)

# global variables
w, h = 432, 240  # evaluation resolution
ref_length = 10  # sampling interval for non-local reference frames
neighbor_stride = 5  # stride of the local temporal window
default_fps = 24


# sample reference frames from the whole video
def get_ref_index(neighbor_ids, length):
    ref_index = []
    for i in range(0, length, ref_length):
        if i not in neighbor_ids:
            ref_index.append(i)
    return ref_index


def main_worker(args):
    args.size = (w, h)

    # set up dataset and data loader
    assert args.dataset in ('davis', 'youtube-vos'), \
        f"{args.dataset} dataset is not supported"
    test_dataset = TestDataset(args)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # set up model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = importlib.import_module('model.' + args.model)
    model = net.InpaintGenerator().to(device)
    data = torch.load(args.ckpt, map_location=device)
    model.load_state_dict(data)
    print(f'Loading from: {args.ckpt}')
    model.eval()

    total_frame_psnr = []
    total_frame_ssim = []

    output_i3d_activations = []
    real_i3d_activations = []

    print('Start evaluation...')

    # create results directory
    result_path = os.path.join('results', f'{args.model}_{args.dataset}')
    if not os.path.exists(result_path):
        os.makedirs(result_path)
    eval_summary = open(
        os.path.join(result_path, f"{args.model}_{args.dataset}_metrics.txt"),
        "w")

    i3d_model = init_i3d_model()

    for index, items in enumerate(test_loader):
        frames, masks, video_name, frames_PIL = items
        video_length = frames.size(1)
        frames, masks = frames.to(device), masks.to(device)
        ori_frames = [
            frames_PIL[i].squeeze().cpu().numpy() for i in range(video_length)
        ]
        comp_frames = [None] * video_length

        # complete holes by our model: slide a local temporal window over the
        # video, augmented with sparsely sampled non-local reference frames
        for f in range(0, video_length, neighbor_stride):
            neighbor_ids = [
                i for i in range(max(0, f - neighbor_stride),
                                 min(video_length, f + neighbor_stride + 1))
            ]
            ref_ids = get_ref_index(neighbor_ids, video_length)
            selected_imgs = frames[:1, neighbor_ids + ref_ids, :, :, :]
            selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
            with torch.no_grad():
                masked_frames = selected_imgs * (1 - selected_masks)
                pred_img, _ = model(masked_frames, len(neighbor_ids))

                # map predictions from [-1, 1] to [0, 255]
                pred_img = (pred_img + 1) / 2
                pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                binary_masks = masks[0, neighbor_ids, :, :, :].cpu().permute(
                    0, 2, 3, 1).numpy().astype(np.uint8)
                for i in range(len(neighbor_ids)):
                    idx = neighbor_ids[i]
                    # paste the prediction into the hole; keep original pixels elsewhere
                    img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
                        + ori_frames[idx] * (1 - binary_masks[i])
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        # average predictions from overlapping windows
                        comp_frames[idx] = comp_frames[idx].astype(
                            np.float32) * 0.5 + img.astype(np.float32) * 0.5

        # calculate metrics
        cur_video_psnr = []
        cur_video_ssim = []
        comp_PIL = []  # to calculate VFID
        frames_PIL = []
        for ori, comp in zip(ori_frames, comp_frames):
            psnr, ssim = calc_psnr_and_ssim(ori, comp)

            cur_video_psnr.append(psnr)
            cur_video_ssim.append(ssim)

            total_frame_psnr.append(psnr)
            total_frame_ssim.append(ssim)

            frames_PIL.append(Image.fromarray(ori.astype(np.uint8)))
            comp_PIL.append(Image.fromarray(comp.astype(np.uint8)))
        cur_psnr = sum(cur_video_psnr) / len(cur_video_psnr)
        cur_ssim = sum(cur_video_ssim) / len(cur_video_ssim)

        # saving i3d activations
        frames_i3d, comp_i3d = calculate_i3d_activations(frames_PIL,
                                                         comp_PIL,
                                                         i3d_model,
                                                         device=device)
        real_i3d_activations.append(frames_i3d)
        output_i3d_activations.append(comp_i3d)

        print(
            f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}'
        )
        eval_summary.write(
            f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}\n'
        )

        # saving images for evaluating warping errors
        if args.save_results:
            save_frame_path = os.path.join(result_path, video_name[0])
            os.makedirs(save_frame_path, exist_ok=False)

            for i, frame in enumerate(comp_frames):
                cv2.imwrite(
                    os.path.join(save_frame_path,
                                 str(i).zfill(5) + '.png'),
                    cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR))

    avg_frame_psnr = sum(total_frame_psnr) / len(total_frame_psnr)
    avg_frame_ssim = sum(total_frame_ssim) / len(total_frame_ssim)

    fid_score = calculate_vfid(real_i3d_activations, output_i3d_activations)
    print('Finish evaluation... Average Frame PSNR/SSIM/VFID: '
          f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
    eval_summary.write(
        'Finish evaluation... Average Frame PSNR/SSIM/VFID: '
        f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
    eval_summary.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='E2FGVI')
    parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str)
    parser.add_argument('--data_root', type=str, required=True)
    parser.add_argument('--model', choices=['e2fgvi', 'e2fgvi_hq'], type=str)
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--save_results', action='store_true', default=False)
    parser.add_argument('--num_workers', default=4, type=int)
    args = parser.parse_args()
    main_worker(args)
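

# Example invocation (a sketch: the checkpoint filename and the data_root
# layout expected by core.dataset.TestDataset are assumptions, not defined
# in this script):
#
#   python evaluate.py --dataset davis --data_root datasets \
#       --model e2fgvi --ckpt release_model/E2FGVI-CVPR22.pth --save_results
#
# Per-frame PSNR/SSIM and the video-level VFID are written to
# results/<model>_<dataset>/<model>_<dataset>_metrics.txt; with
# --save_results, completed frames are also saved as PNGs per video.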