| |
| import os |
| from types import SimpleNamespace |
| import cv2 |
| import torch |
| import shutil |
| import numpy as np |
| from tqdm import tqdm |
| from torch.nn import functional as F |
| import warnings |
| import _thread |
| from queue import Queue |
| import time |
| from .model.pytorch_msssim import ssim_matlab |
|
|
| from deforum_helpers.video_audio_utilities import ffmpeg_stitch_video |
| from deforum_helpers.general_utils import duplicate_pngs_from_folder |
|
|
| warnings.filterwarnings("ignore") |
|
|
| def run_rife_new_video_infer( |
| output=None, |
| model=None, |
| fp16=False, |
| UHD=False, |
| scale=1.0, |
| fps=None, |
| deforum_models_path=None, |
| raw_output_imgs_path=None, |
| img_batch_id=None, |
| ffmpeg_location=None, |
| audio_track=None, |
| interp_x_amount=2, |
| slow_mo_enabled=False, |
| slow_mo_x_amount=2, |
| ffmpeg_crf=17, |
| ffmpeg_preset='veryslow', |
| keep_imgs=False, |
| orig_vid_name = None, |
| srt_path = None): |
|
|
| args = SimpleNamespace() |
| args.output = output |
| args.modelDir = model |
| args.fp16 = fp16 |
| args.UHD = UHD |
| args.scale = scale |
| args.fps = fps |
| args.deforum_models_path = deforum_models_path |
| args.raw_output_imgs_path = raw_output_imgs_path |
| args.img_batch_id = img_batch_id |
| args.ffmpeg_location = ffmpeg_location |
| args.audio_track = audio_track |
| args.interp_x_amount = interp_x_amount |
| args.slow_mo_enabled = slow_mo_enabled |
| args.slow_mo_x_amount = slow_mo_x_amount |
| args.ffmpeg_crf = ffmpeg_crf |
| args.ffmpeg_preset = ffmpeg_preset |
| args.keep_imgs = keep_imgs |
| args.orig_vid_name = orig_vid_name |
|
|
| if args.UHD and args.scale == 1.0: |
| args.scale = 0.5 |
| |
| start_time = time.time() |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| torch.set_grad_enabled(False) |
| if torch.cuda.is_available(): |
| torch.backends.cudnn.enabled = True |
| torch.backends.cudnn.benchmark = True |
| |
| if (args.fp16): |
| torch.set_default_tensor_type(torch.cuda.HalfTensor) |
| if args.modelDir is not None: |
| try: |
| from .rife_new_gen.RIFE_HDv3 import Model |
| except ImportError as e: |
| raise ValueError(f"{args.modelDir} could not be found. Please contact deforum support {e}") |
| except Exception as e: |
| raise ValueError(f"An error occured while trying to import {args.modelDir}: {e}") |
| else: |
| print("Got a request to frame-interpolate but no valid frame interpolation engine value provided. Doing... nothing") |
| return |
| |
| model = Model() |
| if not hasattr(model, 'version'): |
| model.version = 0 |
| model.load_model(args.modelDir, -1, deforum_models_path) |
| model.eval() |
| model.device() |
| |
| print(f"{args.modelDir}.pkl model successfully loaded into memory") |
| print("Interpolation progress (it's OK if it finishes before 100%):") |
| |
| interpolated_path = os.path.join(args.raw_output_imgs_path, 'interpolated_frames_rife') |
| |
| if args.orig_vid_name is not None: |
| custom_interp_path = "{}_{}".format(interpolated_path, args.orig_vid_name) |
| else: |
| custom_interp_path = "{}_{}".format(interpolated_path, args.img_batch_id) |
|
|
| |
| |
| temp_convert_raw_png_path = os.path.join(args.raw_output_imgs_path, "tmp_rife_folder") |
| |
| duplicate_pngs_from_folder(args.raw_output_imgs_path, temp_convert_raw_png_path, args.img_batch_id, args.orig_vid_name) |
| |
| videogen = [] |
| for f in os.listdir(temp_convert_raw_png_path): |
| |
| if '_depth_' not in f: |
| videogen.append(f) |
| tot_frame = len(videogen) |
| videogen.sort(key= lambda x:int(x.split('.')[0])) |
| img_path = os.path.join(temp_convert_raw_png_path, videogen[0]) |
| lastframe = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() |
| videogen = videogen[1:] |
| h, w, _ = lastframe.shape |
| vid_out = None |
|
|
| if not os.path.exists(custom_interp_path): |
| os.mkdir(custom_interp_path) |
|
|
| tmp = max(128, int(128 / args.scale)) |
| ph = ((h - 1) // tmp + 1) * tmp |
| pw = ((w - 1) // tmp + 1) * tmp |
| padding = (0, pw - w, 0, ph - h) |
| pbar = tqdm(total=tot_frame) |
|
|
| write_buffer = Queue(maxsize=500) |
| read_buffer = Queue(maxsize=500) |
| |
| _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen, temp_convert_raw_png_path)) |
| _thread.start_new_thread(clear_write_buffer, (args, write_buffer, custom_interp_path)) |
|
|
| I1 = torch.from_numpy(np.transpose(lastframe, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. |
| I1 = pad_image(I1, args.fp16, padding) |
| temp = None |
| |
| while True: |
| if temp is not None: |
| frame = temp |
| temp = None |
| else: |
| frame = read_buffer.get() |
| if frame is None: |
| break |
| I0 = I1 |
| I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. |
| I1 = pad_image(I1, args.fp16, padding) |
| I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) |
| I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) |
| ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) |
|
|
| break_flag = False |
| if ssim > 0.996: |
| frame = read_buffer.get() |
| if frame is None: |
| break_flag = True |
| frame = lastframe |
| else: |
| temp = frame |
| I1 = torch.from_numpy(np.transpose(frame, (2, 0, 1))).to(device, non_blocking=True).unsqueeze(0).float() / 255. |
| I1 = pad_image(I1, args.fp16, padding) |
| I1 = model.inference(I0, I1, args.scale) |
| I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) |
| ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) |
| frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w] |
|
|
| if ssim < 0.2: |
| output = [] |
| for i in range(args.interp_x_amount - 1): |
| output.append(I0) |
| else: |
| output = make_inference(model, I0, I1, args.interp_x_amount - 1, scale) |
| |
| write_buffer.put(lastframe) |
| for mid in output: |
| mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0))) |
| write_buffer.put(mid[:h, :w]) |
| pbar.update(1) |
| lastframe = frame |
| if break_flag: |
| break |
|
|
| write_buffer.put(lastframe) |
|
|
| while (not write_buffer.empty()): |
| time.sleep(0.1) |
| pbar.close() |
| shutil.rmtree(temp_convert_raw_png_path) |
| |
| print(f"Interpolation \033[0;32mdone\033[0m in {time.time()-start_time:.2f} seconds!") |
| |
| try: |
| print (f"*Passing interpolated frames to ffmpeg...*") |
| vid_out_path = stitch_video(args.img_batch_id, args.fps, custom_interp_path, args.audio_track, args.ffmpeg_location, args.interp_x_amount, args.slow_mo_enabled, args.slow_mo_x_amount, args.ffmpeg_crf, args.ffmpeg_preset, args.keep_imgs, args.orig_vid_name, srt_path=srt_path) |
| |
| if orig_vid_name is not None: |
| shutil.rmtree(raw_output_imgs_path) |
| except Exception as e: |
| print(f'Video stitching gone wrong. *Interpolated frames were saved to HD as backup!*. Actual error: {e}') |
| |
| def clear_write_buffer(user_args, write_buffer, custom_interp_path): |
| cnt = 0 |
|
|
| while True: |
| item = write_buffer.get() |
| if item is None: |
| break |
| filename = '{}/{:0>9d}.png'.format(custom_interp_path, cnt) |
|
|
| cv2.imwrite(filename, item[:, :, ::-1]) |
|
|
| cnt += 1 |
|
|
| def build_read_buffer(user_args, read_buffer, videogen, temp_convert_raw_png_path): |
| for frame in videogen: |
| if not temp_convert_raw_png_path is None: |
| img_path = os.path.join(temp_convert_raw_png_path, frame) |
| frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() |
| read_buffer.put(frame) |
| read_buffer.put(None) |
|
|
| def make_inference(model, I0, I1, n, scale): |
| if model.version >= 3.9: |
| res = [] |
| for i in range(n): |
| res.append(model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale)) |
| return res |
| else: |
| middle = model.inference(I0, I1, scale) |
| if n == 1: |
| return [middle] |
| first_half = make_inference(model, I0, middle, n=n // 2, scale=scale) |
| second_half = make_inference(model, middle, I1, n=n // 2, scale=scale) |
| if n % 2: |
| return [*first_half, middle, *second_half] |
| else: |
| return [*first_half, *second_half] |
|
|
| def pad_image(img, fp16, padding): |
| if (fp16): |
| return F.pad(img, padding).half() |
| else: |
| return F.pad(img, padding) |
|
|
| |
| def stitch_video(img_batch_id, fps, img_folder_path, audio_path, ffmpeg_location, interp_x_amount, slow_mo_enabled, slow_mo_x_amount, f_crf, f_preset, keep_imgs, orig_vid_name, srt_path=None): |
| parent_folder = os.path.dirname(img_folder_path) |
| grandparent_folder = os.path.dirname(parent_folder) |
| if orig_vid_name is not None: |
| mp4_path = os.path.join(grandparent_folder, str(orig_vid_name) +'_RIFE_' + 'x' + str(interp_x_amount)) |
| else: |
| mp4_path = os.path.join(parent_folder, str(img_batch_id) +'_RIFE_' + 'x' + str(interp_x_amount)) |
|
|
| if slow_mo_enabled: |
| mp4_path = mp4_path + '_slomo_x' + str(slow_mo_x_amount) |
| mp4_path = mp4_path + '.mp4' |
|
|
| t = os.path.join(img_folder_path, "%09d.png") |
| add_soundtrack = 'None' |
| if not audio_path is None: |
| add_soundtrack = 'File' |
| |
| exception_raised = False |
| try: |
| ffmpeg_stitch_video(ffmpeg_location=ffmpeg_location, fps=fps, outmp4_path=mp4_path, stitch_from_frame=0, stitch_to_frame=1000000, imgs_path=t, add_soundtrack=add_soundtrack, audio_path=audio_path, crf=f_crf, preset=f_preset, srt_path=srt_path) |
| except Exception as e: |
| exception_raised = True |
| print(f"An error occurred while stitching the video: {e}") |
|
|
| if not exception_raised and not keep_imgs: |
| shutil.rmtree(img_folder_path) |
|
|
| if (keep_imgs and orig_vid_name is not None) or (orig_vid_name is not None and exception_raised is True): |
| shutil.move(img_folder_path, grandparent_folder) |
|
|
| return mp4_path |