# NOTE(review): removed non-code scrape artifact ("Spaces: / Sleeping / Sleeping"
# — Hugging Face Spaces status lines accidentally captured with the file).
| # -*- coding: utf-8 -*- | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| import importlib | |
| import os | |
| import argparse | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| from matplotlib import animation | |
| import torch | |
| from core.utils import to_tensors | |
# Command-line interface for the E2FGVI inpainting demo.
parser = argparse.ArgumentParser(description="E2FGVI")
# Input video: either an .mp4 file or a directory of frame images.
parser.add_argument("-v", "--video", type=str, required=True)
# Path to the pretrained checkpoint loaded with torch.load.
parser.add_argument("-c", "--ckpt", type=str, required=True)
# Directory containing one mask image per frame.
parser.add_argument("-m", "--mask", type=str, required=True)
parser.add_argument("--model", type=str, choices=['e2fgvi', 'e2fgvi_hq'])
# Stride between sampled non-local reference frames.
parser.add_argument("--step", type=int, default=10)
# Max number of reference frames per window; -1 means "sample the whole video".
parser.add_argument("--num_ref", type=int, default=-1)
# Stride between the centers of local (neighboring) frame windows.
parser.add_argument("--neighbor_stride", type=int, default=5)
parser.add_argument("--savefps", type=int, default=24)
# args for e2fgvi_hq (which can handle videos with arbitrary resolution)
parser.add_argument("--set_size", action='store_true', default=False)
parser.add_argument("--width", type=int)
parser.add_argument("--height", type=int)
args = parser.parse_args()

# Module-level aliases read by get_ref_index() and main_worker().
ref_length = args.step  # ref_step
num_ref = args.num_ref
neighbor_stride = args.neighbor_stride
default_fps = args.savefps
| # sample reference frames from the whole video | |
| def get_ref_index(f, neighbor_ids, length): | |
| ref_index = [] | |
| if num_ref == -1: | |
| for i in range(0, length, ref_length): | |
| if i not in neighbor_ids: | |
| ref_index.append(i) | |
| else: | |
| start_idx = max(0, f - ref_length * (num_ref // 2)) | |
| end_idx = min(length, f + ref_length * (num_ref // 2)) | |
| for i in range(start_idx, end_idx + 1, ref_length): | |
| if i not in neighbor_ids: | |
| if len(ref_index) > num_ref: | |
| break | |
| ref_index.append(i) | |
| return ref_index | |
| # read frame-wise masks | |
| def read_mask(mpath, size): | |
| masks = [] | |
| mnames = os.listdir(mpath) | |
| mnames.sort() | |
| for mp in mnames: | |
| m = Image.open(os.path.join(mpath, mp)) | |
| m = m.resize(size, Image.NEAREST) | |
| m = np.array(m.convert('L')) | |
| m = np.array(m > 0).astype(np.uint8) | |
| m = cv2.dilate(m, | |
| cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)), | |
| iterations=4) | |
| masks.append(Image.fromarray(m * 255)) | |
| return masks | |
| # read frames from video | |
| def read_frame_from_videos(args): | |
| vname = args.video | |
| frames = [] | |
| if args.use_mp4: | |
| vidcap = cv2.VideoCapture(vname) | |
| success, image = vidcap.read() | |
| count = 0 | |
| while success: | |
| image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) | |
| frames.append(image) | |
| success, image = vidcap.read() | |
| count += 1 | |
| else: | |
| lst = os.listdir(vname) | |
| lst.sort() | |
| fr_lst = [vname + '/' + name for name in lst] | |
| for fr in fr_lst: | |
| image = cv2.imread(fr) | |
| image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) | |
| frames.append(image) | |
| return frames | |
| # resize frames | |
| def resize_frames(frames, size=None): | |
| if size is not None: | |
| frames = [f.resize(size) for f in frames] | |
| else: | |
| size = frames[0].size | |
| return frames, size | |
def main_worker():
    """Run E2FGVI inpainting over the whole input video, save the result as
    an .mp4 in ./results, and display an original-vs-result animation."""
    # set up models
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.model == "e2fgvi":
        # The base model only supports this fixed 432x240 resolution.
        size = (432, 240)
    elif args.set_size:
        # e2fgvi_hq with an explicit user-chosen resolution.
        size = (args.width, args.height)
    else:
        # None -> resize_frames keeps the video's native resolution.
        size = None
    # Dynamically load model.e2fgvi or model.e2fgvi_hq.
    net = importlib.import_module('model.' + args.model)
    model = net.InpaintGenerator().to(device)
    data = torch.load(args.ckpt, map_location=device)
    model.load_state_dict(data)
    print(f'Loading model from: {args.ckpt}')
    model.eval()

    # prepare dataset
    args.use_mp4 = True if args.video.endswith('.mp4') else False
    print(
        f'Loading videos and masks from: {args.video} | INPUT MP4 format: {args.use_mp4}'
    )
    frames = read_frame_from_videos(args)
    frames, size = resize_frames(frames, size)
    h, w = size[1], size[0]
    video_length = len(frames)
    # Normalize frames from [0, 1] to [-1, 1]; shape (1, T, C, H, W).
    imgs = to_tensors()(frames).unsqueeze(0) * 2 - 1
    frames = [np.array(f).astype(np.uint8) for f in frames]

    masks = read_mask(args.mask, size)
    # (H, W, 1) uint8 masks used for per-pixel compositing below.
    binary_masks = [
        np.expand_dims((np.array(m) != 0).astype(np.uint8), 2) for m in masks
    ]
    masks = to_tensors()(masks).unsqueeze(0)
    imgs, masks = imgs.to(device), masks.to(device)
    comp_frames = [None] * video_length

    # completing holes by e2fgvi
    print(f'Start test...')
    # Slide a local window across the video in steps of neighbor_stride.
    for f in tqdm(range(0, video_length, neighbor_stride)):
        # Local frames around f (inclusive window, clamped to the video).
        neighbor_ids = [
            i for i in range(max(0, f - neighbor_stride),
                             min(video_length, f + neighbor_stride + 1))
        ]
        # Non-local reference frames sampled from the rest of the video.
        ref_ids = get_ref_index(f, neighbor_ids, video_length)
        selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
        selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
        with torch.no_grad():
            # Zero out the masked (hole) pixels before feeding the model.
            masked_imgs = selected_imgs * (1 - selected_masks)
            # Mirror-pad H to a multiple of 60 and W to a multiple of 108 —
            # presumably required by the model's internal downsampling;
            # the crop below removes the padding again. TODO confirm.
            mod_size_h = 60
            mod_size_w = 108
            h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
            w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
            masked_imgs = torch.cat(
                [masked_imgs, torch.flip(masked_imgs, [3])],
                3)[:, :, :, :h + h_pad, :]
            masked_imgs = torch.cat(
                [masked_imgs, torch.flip(masked_imgs, [4])],
                4)[:, :, :, :, :w + w_pad]
            pred_imgs, _ = model(masked_imgs, len(neighbor_ids))
            # Crop the padding and map predictions back from [-1, 1] to [0, 255].
            pred_imgs = pred_imgs[:, :, :h, :w]
            pred_imgs = (pred_imgs + 1) / 2
            pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
            for i in range(len(neighbor_ids)):
                idx = neighbor_ids[i]
                # Composite: model output inside the hole, original outside.
                img = np.array(pred_imgs[i]).astype(
                    np.uint8) * binary_masks[idx] + frames[idx] * (
                        1 - binary_masks[idx])
                if comp_frames[idx] is None:
                    comp_frames[idx] = img
                else:
                    # Frame seen by two overlapping windows: average the
                    # two predictions to smooth window-boundary seams.
                    comp_frames[idx] = comp_frames[idx].astype(
                        np.float32) * 0.5 + img.astype(np.float32) * 0.5

    # saving videos
    print('Saving videos...')
    save_dir_name = 'results'
    ext_name = '_results.mp4'
    save_base_name = args.video.split('/')[-1]
    save_name = save_base_name.replace(
        '.mp4', ext_name) if args.use_mp4 else save_base_name + ext_name
    if not os.path.exists(save_dir_name):
        os.makedirs(save_dir_name)
    save_path = os.path.join(save_dir_name, save_name)
    writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"),
                             default_fps, size)
    for f in range(video_length):
        comp = comp_frames[f].astype(np.uint8)
        # Frames are RGB; this channel swap yields the BGR order
        # cv2.VideoWriter expects (BGR2RGB and RGB2BGR are the same swap).
        writer.write(cv2.cvtColor(comp, cv2.COLOR_BGR2RGB))
    writer.release()
    print(f'Finish test! The result video is saved in: {save_path}.')

    # show results
    print('Let us enjoy the result!')
    fig = plt.figure('Let us enjoy the result')
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.axis('off')
    ax1.set_title('Original Video')
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.axis('off')
    ax2.set_title('Our Result')
    imdata1 = ax1.imshow(frames[0])
    imdata2 = ax2.imshow(comp_frames[0].astype(np.uint8))

    def update(idx):
        # Animation callback: swap in frame `idx` on both axes.
        imdata1.set_data(frames[idx])
        imdata2.set_data(comp_frames[idx].astype(np.uint8))

    fig.tight_layout()
    anim = animation.FuncAnimation(fig,
                                   update,
                                   frames=len(frames),
                                   interval=50)
    plt.show()
# Entry point: run inference only when executed as a script, not on import.
if __name__ == '__main__':
    main_worker()