# Vendored from commit 96da58e ("Add phantom project with submodules and dependencies").
# -*- coding: utf-8 -*-
import cv2
import numpy as np
import importlib
import os
import argparse
from PIL import Image
import torch
from torch.utils.data import DataLoader
from core.dataset import TestDataset
from core.metrics import calc_psnr_and_ssim, calculate_i3d_activations, calculate_vfid, init_i3d_model
# global configuration
w, h = 432, 240       # evaluation frame size (width, height)
ref_length = 10       # sampling stride for reference frames
neighbor_stride = 5   # half-width of the local temporal window
default_fps = 24


def get_ref_index(neighbor_ids, length):
    """Sample reference-frame indices from the whole video.

    Picks every ``ref_length``-th frame index in ``[0, length)`` and
    drops any index already present in ``neighbor_ids``, so references
    never duplicate the local window.
    """
    return [
        idx for idx in range(0, length, ref_length)
        if idx not in neighbor_ids
    ]
def main_worker(args):
    """Evaluate a video-inpainting generator on DAVIS / YouTube-VOS.

    For every test video: slides a temporal window of neighboring frames
    (augmented with evenly sampled reference frames) through the model,
    composites predictions back into the masked regions, and accumulates
    PSNR/SSIM per frame plus I3D activations for VFID.  Results are
    echoed to stdout and written to
    ``results/<model>_<dataset>/<model>_<dataset>_metrics.txt``.

    Args:
        args: parsed CLI namespace; must carry ``dataset``, ``model``,
            ``ckpt``, ``num_workers`` and ``save_results``.  ``args.size``
            is set here from the module-level ``(w, h)``.
    """
    args.size = (w, h)

    # set up datasets and data loader
    assert (args.dataset == 'davis') or args.dataset == 'youtube-vos', \
        f"{args.dataset} dataset is not supported"
    test_dataset = TestDataset(args)
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # set up models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = importlib.import_module('model.' + args.model)
    model = net.InpaintGenerator().to(device)
    data = torch.load(args.ckpt, map_location=device)
    model.load_state_dict(data)
    print(f'Loading from: {args.ckpt}')
    model.eval()

    total_frame_psnr = []
    total_frame_ssim = []
    output_i3d_activations = []
    real_i3d_activations = []

    print('Start evaluation...')

    # create results directory; exist_ok avoids the exists()/makedirs() race
    result_path = os.path.join('results', f'{args.model}_{args.dataset}')
    os.makedirs(result_path, exist_ok=True)

    i3d_model = init_i3d_model()

    # `with` guarantees the summary file is flushed and closed even if a
    # video raises mid-evaluation (previously the handle leaked on error).
    with open(
            os.path.join(result_path,
                         f"{args.model}_{args.dataset}_metrics.txt"),
            "w") as eval_summary:
        for index, items in enumerate(test_loader):
            frames, masks, video_name, frames_PIL = items
            video_length = frames.size(1)
            frames, masks = frames.to(device), masks.to(device)
            ori_frames = [
                frames_PIL[i].squeeze().cpu().numpy()
                for i in range(video_length)
            ]
            comp_frames = [None] * video_length

            # complete holes by our model
            for f in range(0, video_length, neighbor_stride):
                neighbor_ids = [
                    i for i in range(max(0, f - neighbor_stride),
                                     min(video_length,
                                         f + neighbor_stride + 1))
                ]
                ref_ids = get_ref_index(neighbor_ids, video_length)
                selected_imgs = frames[:1, neighbor_ids + ref_ids, :, :, :]
                selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
                with torch.no_grad():
                    masked_frames = selected_imgs * (1 - selected_masks)
                    pred_img, _ = model(masked_frames, len(neighbor_ids))
                    # generator output is in [-1, 1]; map to [0, 255]
                    pred_img = (pred_img + 1) / 2
                    pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                    binary_masks = masks[0, neighbor_ids, :, :, :].cpu(
                    ).permute(0, 2, 3, 1).numpy().astype(np.uint8)
                    for i in range(len(neighbor_ids)):
                        idx = neighbor_ids[i]
                        # paste prediction inside the hole, keep the
                        # original pixels everywhere else
                        img = np.array(pred_img[i]).astype(np.uint8) \
                            * binary_masks[i] \
                            + ori_frames[idx] * (1 - binary_masks[i])
                        if comp_frames[idx] is None:
                            comp_frames[idx] = img
                        else:
                            # frame covered by two overlapping windows:
                            # average the two completions
                            comp_frames[idx] = comp_frames[idx].astype(
                                np.float32) * 0.5 \
                                + img.astype(np.float32) * 0.5

            # calculate metrics
            cur_video_psnr = []
            cur_video_ssim = []
            comp_PIL = []  # to calculate VFID
            frames_PIL = []
            for ori, comp in zip(ori_frames, comp_frames):
                psnr, ssim = calc_psnr_and_ssim(ori, comp)

                cur_video_psnr.append(psnr)
                cur_video_ssim.append(ssim)

                total_frame_psnr.append(psnr)
                total_frame_ssim.append(ssim)

                frames_PIL.append(Image.fromarray(ori.astype(np.uint8)))
                comp_PIL.append(Image.fromarray(comp.astype(np.uint8)))
            cur_psnr = sum(cur_video_psnr) / len(cur_video_psnr)
            cur_ssim = sum(cur_video_ssim) / len(cur_video_ssim)

            # saving i3d activations
            frames_i3d, comp_i3d = calculate_i3d_activations(frames_PIL,
                                                             comp_PIL,
                                                             i3d_model,
                                                             device=device)
            real_i3d_activations.append(frames_i3d)
            output_i3d_activations.append(comp_i3d)

            print(
                f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}'
            )
            eval_summary.write(
                f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f}\n'
            )

            # saving images for evaluating warpping errors
            if args.save_results:
                save_frame_path = os.path.join(result_path, video_name[0])
                # exist_ok=True lets a re-run overwrite a previous dump
                # instead of dying with FileExistsError
                os.makedirs(save_frame_path, exist_ok=True)

                for i, frame in enumerate(comp_frames):
                    cv2.imwrite(
                        os.path.join(save_frame_path,
                                     str(i).zfill(5) + '.png'),
                        cv2.cvtColor(frame.astype(np.uint8),
                                     cv2.COLOR_RGB2BGR))

        avg_frame_psnr = sum(total_frame_psnr) / len(total_frame_psnr)
        avg_frame_ssim = sum(total_frame_ssim) / len(total_frame_ssim)

        fid_score = calculate_vfid(real_i3d_activations,
                                   output_i3d_activations)
        print('Finish evaluation... Average Frame PSNR/SSIM/VFID: '
              f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
        eval_summary.write(
            'Finish evaluation... Average Frame PSNR/SSIM/VFID: '
            f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f}')
if __name__ == '__main__':
    # CLI entry point: parse evaluation options and run the worker.
    parser = argparse.ArgumentParser(description='E2FGVI')
    parser.add_argument('--dataset',
                        type=str,
                        choices=['davis', 'youtube-vos'])
    parser.add_argument('--data_root', required=True, type=str)
    parser.add_argument('--model', type=str, choices=['e2fgvi', 'e2fgvi_hq'])
    parser.add_argument('--ckpt', required=True, type=str)
    parser.add_argument('--save_results', default=False, action='store_true')
    parser.add_argument('--num_workers', type=int, default=4)
    main_worker(parser.parse_args())