Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import platform | |
| import struct | |
| import subprocess | |
| import time | |
| from typing import List | |
| import cv2 | |
| import numpy as np | |
| import torch.multiprocessing as mp | |
| from numba import njit | |
| import sys | |
| sys.path.append("./src/ebsynth/") | |
| import blender.histogram_blend as histogram_blend | |
| from blender.guide import (BaseGuide, ColorGuide, EdgeGuide, PositionalGuide, | |
| TemporalGuide) | |
| from blender.poisson_fusion import poisson_fusion | |
| from blender.video_sequence import VideoSequence | |
| from flow.flow_utils import flow_calc | |
| from src.video_util import frame_to_video | |
# When True, each ebsynth command line is printed and its output is not
# captured.
OPEN_EBSYNTH_LOG = False
# Upper bound on the number of worker processes; main() overwrites it from
# the --n_proc command-line option.
MAX_PROCESS = 8

# Relative path of the ebsynth executable for every recognised platform name.
_EBSYNTH_BIN_BY_OS = {
    'Windows': '.\\src\\ebsynth\\deps\\ebsynth\\bin\\ebsynth.exe',
    'Linux': './src/ebsynth/deps/ebsynth/bin/ebsynth',
    'Darwin': './src/ebsynth/deps/ebsynth/bin/ebsynth.app',
}

os_str = platform.system()
if os_str not in _EBSYNTH_BIN_BY_OS:
    # No known binary for this platform: report and stop at import time.
    print('Cannot recognize OS. Run Ebsynth failed.')
    exit(0)
ebsynth_bin = _EBSYNTH_BIN_BY_OS[os_str]
def g_error_mask_loop(H, W, dist1, dist2, output, weight1, weight2):
    """Fill ``output`` in place with a per-pixel 0/1 selection mask.

    For every pixel of the top-left ``H`` x ``W`` window the value written is
    0 when ``weight1 * dist1 < weight2 * dist2`` and 1 otherwise. A zero
    weight overrides the comparison: ``weight1 == 0`` forces 0 everywhere,
    otherwise ``weight2 == 0`` forces 1 everywhere.

    The window is computed with whole-array NumPy operations instead of a
    Python loop over every pixel, which is far faster on image-sized inputs.

    Args:
        H: Number of rows to fill.
        W: Number of columns to fill.
        dist1: 2-D array of errors for the first candidate.
        dist2: 2-D array of errors for the second candidate.
        output: 2-D array that receives the mask (modified in place).
        weight1: Scalar multiplier applied to ``dist1``.
        weight2: Scalar multiplier applied to ``dist2``.
    """
    window = output[:H, :W]
    if weight1 == 0:
        window[...] = 0
    elif weight2 == 0:
        window[...] = 1
    else:
        first_is_smaller = weight1 * dist1[:H, :W] < weight2 * dist2[:H, :W]
        # Negate the strict "<" instead of using ">=" so that comparisons
        # involving NaN still yield 1, as in the per-pixel if/else.
        window[...] = ~first_is_smaller
def g_error_mask(dist1, dist2, weight1=1, weight2=1):
    """Return a fresh byte mask comparing two weighted error maps.

    An ``np.byte`` array with the shape of ``dist1`` is allocated and handed
    to ``g_error_mask_loop``, which fills it with 0 where
    ``weight1 * dist1 < weight2 * dist2`` and 1 elsewhere (a zero weight
    overrides the comparison).

    Args:
        dist1: 2-D error array for the first candidate.
        dist2: 2-D error array for the second candidate.
        weight1: Scalar multiplier for ``dist1``.
        weight2: Scalar multiplier for ``dist2``.

    Returns:
        The newly allocated mask array.
    """
    rows, cols = dist1.shape
    mask = np.empty_like(dist1, dtype=np.byte)
    g_error_mask_loop(rows, cols, dist1, dist2, mask, weight1, weight2)
    return mask
def create_sequence(base_dir, key_ind, key_dir):
    """Construct the ``VideoSequence`` used by the rest of this script.

    ``base_dir``, ``key_ind`` and ``key_dir`` are forwarded together with the
    fixed literals ``'video'``, ``'tmp'`` and ``'%04d.png'`` (twice).
    NOTE(review): these presumably name the input sub-folder, the temporary
    sub-folder and the frame/key file-name patterns -- confirm against the
    ``VideoSequence`` constructor.
    """
    return VideoSequence(base_dir, key_ind, 'video', key_dir, 'tmp',
                         '%04d.png', '%04d.png')
def process_one_sequence(i, video_sequence: VideoSequence):
    """Run ebsynth over sequence ``i``, once forward and once backward.

    For each direction this:

    1. calls ``flow_calc.get_flow`` on every pair of consecutive input frames
       together with the matching ``flow_seq`` entry (the return value is not
       used here; presumably the flow is stored at that path -- TODO confirm);
    2. builds the colour, edge, temporal and positional guides, weighted
       6, 0.5, 0.5 and 2 respectively;
    3. copies the key image to ``output_seq[0]`` and invokes the ebsynth
       binary for every other frame index.

    The forward pass uses key image ``i``; the backward pass uses ``i + 1``.

    A non-zero ebsynth exit status is reported on stdout (with any captured
    stderr) instead of being silently discarded; processing then continues
    with the next frame.

    Args:
        i: Index of the sequence to process.
        video_sequence: Provider of all frame, flow and guide paths.
    """
    interval = video_sequence.interval(i)
    for is_forward in [True, False]:
        input_seq = video_sequence.get_input_sequence(i, is_forward)
        output_seq = video_sequence.get_output_sequence(i, is_forward)
        flow_seq = video_sequence.get_flow_sequence(i, is_forward)
        key_img_id = i if is_forward else i + 1
        key_img = video_sequence.get_key_img(key_img_id)

        for j in range(interval - 1):
            i1 = cv2.imread(input_seq[j])
            i2 = cv2.imread(input_seq[j + 1])
            flow_calc.get_flow(i1, i2, flow_seq[j])

        guides: List[BaseGuide] = [
            ColorGuide(input_seq),
            EdgeGuide(input_seq,
                      video_sequence.get_edge_sequence(i, is_forward)),
            TemporalGuide(key_img, output_seq, flow_seq,
                          video_sequence.get_temporal_sequence(i, is_forward)),
            PositionalGuide(flow_seq,
                            video_sequence.get_pos_sequence(i, is_forward)),
        ]
        weights = [6, 0.5, 0.5, 2]

        for j in range(interval):
            if j == 0:
                # The key frame is already stylised: copy it through as is.
                img = cv2.imread(key_img)
                cv2.imwrite(output_seq[0], img)
                continue

            cmd_parts = [f'{ebsynth_bin} -style {os.path.abspath(key_img)}']
            cmd_parts.extend(g.get_cmd(j, w) for g, w in zip(guides, weights))
            cmd_parts.append(f'-output {os.path.abspath(output_seq[j])}'
                             ' -searchvoteiters 12 -patchmatchiters 6')
            cmd = ' '.join(cmd_parts)
            if OPEN_EBSYNTH_LOG:
                print(cmd)
            # NOTE(review): the command stays a shell string because the
            # guides return pre-formatted argument text; paths containing
            # spaces or shell metacharacters are therefore not safe.
            result = subprocess.run(cmd,
                                    shell=True,
                                    capture_output=not OPEN_EBSYNTH_LOG)
            if result.returncode != 0:
                print(f'ebsynth exited with code {result.returncode} '
                      f'for {output_seq[j]}')
                if result.stderr:
                    print(result.stderr.decode(errors='replace'))
def process_sequences(i_arr, video_sequence: VideoSequence):
    """Process each sequence index in ``i_arr``, one after another.

    Used as the target of the worker processes started by ``run_ebsynth``.
    """
    for seq_idx in i_arr:
        process_one_sequence(seq_idx, video_sequence)
| def run_ebsynth(video_sequence: VideoSequence): | |
| beg = time.time() | |
| processes = [] | |
| mp.set_start_method('spawn') | |
| n_process = min(MAX_PROCESS, video_sequence.n_seq) | |
| cnt = video_sequence.n_seq // n_process | |
| remainder = video_sequence.n_seq % n_process | |
| prev_idx = 0 | |
| for i in range(n_process): | |
| task_cnt = cnt + 1 if i < remainder else cnt | |
| i_arr = list(range(prev_idx, prev_idx + task_cnt)) | |
| prev_idx += task_cnt | |
| p = mp.Process(target=process_sequences, args=(i_arr, video_sequence)) | |
| p.start() | |
| processes.append(p) | |
| for p in processes: | |
| p.join() | |
| end = time.time() | |
| print(f'ebsynth: {end-beg}') | |
def assemble_min_error_img_loop(H, W, a, b, error_mask, out):
    """Compose ``out`` in place from ``a`` and ``b`` according to a mask.

    Within the top-left ``H`` x ``W`` window, each pixel of ``out`` is taken
    from ``a`` where ``error_mask`` equals 0 and from ``b`` everywhere else.
    Boolean-mask assignment replaces the per-pixel Python loop; indexing on
    the first two axes means trailing axes (e.g. colour channels) are copied
    along with each selected pixel.

    Args:
        H: Number of rows to fill.
        W: Number of columns to fill.
        a: Source array used where the mask is 0.
        b: Source array used where the mask is non-zero.
        error_mask: 2-D selection mask.
        out: Destination array (modified in place).
    """
    take_a = error_mask[:H, :W] == 0
    take_b = ~take_a
    window = out[:H, :W]
    window[take_a] = a[:H, :W][take_a]
    window[take_b] = b[:H, :W][take_b]
def assemble_min_error_img(a, b, error_mask):
    """Return a new image mixing ``a`` and ``b`` pixel by pixel.

    A result array shaped like ``a`` is allocated and filled by
    ``assemble_min_error_img_loop``: pixels come from ``a`` where
    ``error_mask`` is 0 and from ``b`` otherwise.
    """
    height, width = a.shape[:2]
    composed = np.empty_like(a)
    assemble_min_error_img_loop(height, width, a, b, error_mask, composed)
    return composed
def load_error(bin_path, img_shape):
    """Load a per-pixel float32 error map from a binary file.

    The file is read as a native 8-byte signed integer holding the pixel
    count, followed by exactly that many native 32-bit floats.

    Args:
        bin_path: Path of the binary file.
        img_shape: Shape of the matching image; only the first two entries
            (height, width) are used.

    Returns:
        A writable ``float32`` array of shape ``(height, width)``.

    Raises:
        ValueError: If the file is shorter than the header, the stored pixel
            count differs from ``height * width``, or the payload length does
            not match that count. These are explicit errors rather than
            ``assert`` so the checks survive ``python -O``.
    """
    height, width = img_shape[0], img_shape[1]
    img_size = height * width
    with open(bin_path, 'rb') as fp:
        raw = fp.read()

    header_size = struct.calcsize('q')
    if len(raw) < header_size:
        raise ValueError(f'{bin_path}: file too short for the size header')
    (stored_size,) = struct.unpack_from('q', raw)
    if stored_size != img_size:
        raise ValueError(f'{bin_path}: stored size {stored_size} does not '
                         f'match image size {img_size}')
    expected_payload = img_size * np.dtype(np.float32).itemsize
    if len(raw) - header_size != expected_payload:
        raise ValueError(f'{bin_path}: expected {expected_payload} payload '
                         f'bytes, found {len(raw) - header_size}')

    # Decode the payload in one step instead of unpacking one Python float
    # per pixel; copy so the result is writable and independent of ``raw``.
    flat = np.frombuffer(raw, dtype=np.float32, count=img_size,
                         offset=header_size)
    return flat.reshape(height, width).copy()
def process_seq(video_sequence: VideoSequence,
                i,
                blend_histogram=True,
                blend_gradient=True):
    """Blend the forward and backward ebsynth outputs of sequence ``i``.

    The key image is written unchanged as the first blended frame. For each
    following frame the two candidate outputs are merged using a mask built
    from their error maps (``g_error_mask``). From the second blended frame
    on, the previous mask is warped with the optical flow and OR-ed into the
    new one.

    Args:
        video_sequence: Provider of every path used here.
        i: Index of the sequence to blend.
        blend_histogram: When true, combine the candidates with
            ``histogram_blend.blend``; otherwise use a plain weighted sum.
        blend_gradient: When true, pass the result through
            ``poisson_fusion`` before writing it.
    """
    key1_img = cv2.imread(video_sequence.get_key_img(i))
    img_shape = key1_img.shape
    interval = video_sequence.interval(i)
    beg_id = video_sequence.get_sequence_beg_id(i)

    fwd_paths = video_sequence.get_output_sequence(i)
    bwd_paths = video_sequence.get_output_sequence(i, False)
    # Error-map paths are derived before the backward list is reordered.
    fwd_bins = [path.replace('jpg', 'bin') for path in fwd_paths]
    bwd_bins = [path.replace('jpg', 'bin') for path in bwd_paths]
    # Keep the first backward entry in place and reverse the remainder.
    bwd_paths = [bwd_paths[0], *reversed(bwd_paths[1:])]
    input_paths = video_sequence.get_input_sequence(i)

    oas = [cv2.imread(path) for path in fwd_paths]
    obs = [cv2.imread(path) for path in bwd_paths]
    inputs = [cv2.imread(path) for path in input_paths]
    flow_seq = video_sequence.get_flow_sequence(i)

    n_inner = interval - 1
    dist1s, dist2s = [], []
    for k in range(n_inner):
        bin_a, bin_b = fwd_bins[k + 1], bwd_bins[k + 1]
        dist1s.append(load_error(bin_a, img_shape))
        dist2s.append(load_error(bin_b, img_shape))

    lb, ub = 0, 1
    beg = time.time()
    p_mask = None

    # The key frame is written through unchanged.
    cv2.imwrite(video_sequence.get_blending_img(beg_id), key1_img)

    for k in range(n_inner):
        blend_out_path = video_sequence.get_blending_img(beg_id + k + 1)
        oa, ob = oas[k + 1], obs[k + 1]

        # weight1 grows linearly from lb with the frame offset.
        weight1 = k / (interval - 1) * (ub - lb) + lb
        weight2 = 1 - weight1
        mask = g_error_mask(dist1s[k], dist2s[k], weight1, weight2)

        if p_mask is not None:
            flow_path = flow_seq[k]
            flow = flow_calc.get_flow(inputs[k], inputs[k + 1], flow_path)
            p_mask = flow_calc.warp(p_mask, flow, 'nearest')
            mask = p_mask | mask
        p_mask = mask

        min_error_img = assemble_min_error_img(oa, ob, mask)
        if blend_histogram:
            hb_res = histogram_blend.blend(oa, ob, min_error_img,
                                           (1 - weight1), (1 - weight2))
        else:
            hb_res = ((1 - weight1) * oa.astype(np.float32) +
                      (1 - weight2) * ob.astype(np.float32))

        res = poisson_fusion(hb_res, oa, ob, mask) if blend_gradient else hb_res
        cv2.imwrite(blend_out_path, res)

    print('others:', time.time() - beg)
def main(args):
    """Run the whole pipeline from parsed command-line arguments.

    Steps: set ``MAX_PROCESS`` from ``args.n_proc``; build the video
    sequence; run ebsynth unless ``args.ne`` is set; blend every sequence
    (histogram blending always on, Poisson blending when ``args.ps`` is
    set); write the output video when ``args.output`` is given; remove
    intermediate output unless ``args.tmp`` is set.
    """
    global MAX_PROCESS
    MAX_PROCESS = args.n_proc

    video_sequence = create_sequence(f'{args.name}', args.key_ind, args.key)
    if not args.ne:
        run_ebsynth(video_sequence)

    use_histogram = True
    use_poisson = args.ps
    for seq_idx in range(video_sequence.n_seq):
        process_seq(video_sequence, seq_idx, use_histogram, use_poisson)

    if args.output:
        frame_to_video(args.output, video_sequence.blending_dir, args.fps,
                       False)
    if not args.tmp:
        video_sequence.remove_out_and_tmp()
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('name', type=str, help='Path to input video')
    parser.add_argument(
        '--output', type=str, default=None, help='Path to output video')
    parser.add_argument(
        '--fps', type=float, default=30, help='The FPS of output video')
    parser.add_argument(
        '--key_ind', type=int, nargs='+', default=[1], help='key frame index')
    parser.add_argument(
        '--key',
        type=str,
        default='keys0',
        help='The subfolder name of stylized key frames')
    parser.add_argument(
        '--n_proc', type=int, default=8, help='The max process count')
    # Boolean switches, all off by default.
    parser.add_argument(
        '-ps', action='store_true', help='Use poisson gradient blending')
    parser.add_argument(
        '-ne',
        action='store_true',
        help='Do not run ebsynth (use previous ebsynth output)')
    parser.add_argument(
        '-tmp', action='store_true', help='Keep temporary output')

    args = parser.parse_args()
    main(args)