| import os |
| import shutil |
| from enum import Enum |
|
|
| import cv2 |
| import einops |
| import gradio as gr |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torchvision.transforms as T |
| from blendmodes.blend import BlendType, blendLayers |
| from PIL import Image |
| from pytorch_lightning import seed_everything |
| from safetensors.torch import load_file |
| from skimage import exposure |
|
|
| import src.import_util |
| from ControlNet.annotator.canny import CannyDetector |
| from ControlNet.annotator.hed import HEDdetector |
| from ControlNet.annotator.midas import MidasDetector |
| from ControlNet.annotator.util import HWC3 |
| from ControlNet.cldm.model import create_model, load_state_dict |
| from gmflow_module.gmflow.gmflow import GMFlow |
| from flow.flow_utils import get_warped_and_mask |
| from sd_model_cfg import model_dict |
| from src.config import RerenderConfig |
| from src.controller import AttentionControl |
| from src.ddim_v_hacked import DDIMVSampler |
| from src.img_util import find_flat_region, numpy2tensor |
| from src.video_util import (frame_to_video, get_fps, get_frame_count, |
| prepare_frames) |
|
|
| import huggingface_hub |
|
|
| REPO_NAME = 'Anonymous-sub/Rerender' |
|
|
| huggingface_hub.hf_hub_download(REPO_NAME, |
| 'pexels-koolshooters-7322716.mp4', |
| local_dir='videos') |
| huggingface_hub.hf_hub_download( |
| REPO_NAME, |
| 'pexels-antoni-shkraba-8048492-540x960-25fps.mp4', |
| local_dir='videos') |
| huggingface_hub.hf_hub_download( |
| REPO_NAME, |
| 'pexels-cottonbro-studio-6649832-960x506-25fps.mp4', |
| local_dir='videos') |
|
|
| inversed_model_dict = dict() |
| for k, v in model_dict.items(): |
| inversed_model_dict[v] = k |
|
|
| to_tensor = T.PILToTensor() |
| blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18)) |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
| class ProcessingState(Enum): |
| NULL = 0 |
| FIRST_IMG = 1 |
| KEY_IMGS = 2 |
|
|
|
|
| MAX_KEYFRAME = float(os.environ.get('MAX_KEYFRAME', 8)) |
|
|
|
|
| class GlobalState: |
|
|
| def __init__(self): |
| self.sd_model = None |
| self.ddim_v_sampler = None |
| self.detector_type = None |
| self.detector = None |
| self.controller = None |
| self.processing_state = ProcessingState.NULL |
| flow_model = GMFlow( |
| feature_channels=128, |
| num_scales=1, |
| upsample_factor=8, |
| num_head=1, |
| attention_type='swin', |
| ffn_dim_expansion=4, |
| num_transformer_layers=6, |
| ).to(device) |
|
|
| checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth', |
| map_location=lambda storage, loc: storage) |
| weights = checkpoint['model'] if 'model' in checkpoint else checkpoint |
| flow_model.load_state_dict(weights, strict=False) |
| flow_model.eval() |
| self.flow_model = flow_model |
|
|
| def update_controller(self, inner_strength, mask_period, cross_period, |
| ada_period, warp_period): |
| self.controller = AttentionControl(inner_strength, mask_period, |
| cross_period, ada_period, |
| warp_period) |
|
|
| def update_sd_model(self, sd_model, control_type): |
| if sd_model == self.sd_model: |
| return |
| self.sd_model = sd_model |
| model = create_model('./ControlNet/models/cldm_v15.yaml').cpu() |
| if control_type == 'HED': |
| model.load_state_dict( |
| load_state_dict(huggingface_hub.hf_hub_download( |
| 'lllyasviel/ControlNet', './models/control_sd15_hed.pth'), |
| location=device)) |
| elif control_type == 'canny': |
| model.load_state_dict( |
| load_state_dict(huggingface_hub.hf_hub_download( |
| 'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'), |
| location=device)) |
| elif control_type == 'depth': |
| model.load_state_dict( |
| load_state_dict(huggingface_hub.hf_hub_download( |
| 'lllyasviel/ControlNet', 'models/control_sd15_depth.pth'), |
| location=device)) |
|
|
| model.to(device) |
| sd_model_path = model_dict[sd_model] |
| if len(sd_model_path) > 0: |
| repo_name = REPO_NAME |
| |
| if sd_model.count('/') == 1: |
| repo_name = sd_model |
|
|
| model_ext = os.path.splitext(sd_model_path)[1] |
| downloaded_model = huggingface_hub.hf_hub_download( |
| repo_name, sd_model_path) |
| if model_ext == '.safetensors': |
| model.load_state_dict(load_file(downloaded_model), |
| strict=False) |
| elif model_ext == '.ckpt' or model_ext == '.pth': |
| model.load_state_dict( |
| torch.load(downloaded_model)['state_dict'], strict=False) |
|
|
| try: |
| model.first_stage_model.load_state_dict(torch.load( |
| huggingface_hub.hf_hub_download( |
| 'stabilityai/sd-vae-ft-mse-original', |
| 'vae-ft-mse-840000-ema-pruned.ckpt'))['state_dict'], |
| strict=False) |
| except Exception: |
| print('Warning: We suggest you download the fine-tuned VAE', |
| 'otherwise the generation quality will be degraded') |
|
|
| self.ddim_v_sampler = DDIMVSampler(model) |
|
|
| def clear_sd_model(self): |
| self.sd_model = None |
| self.ddim_v_sampler = None |
| if device == 'cuda': |
| torch.cuda.empty_cache() |
|
|
| def update_detector(self, control_type, canny_low=100, canny_high=200): |
| if self.detector_type == control_type: |
| return |
| if control_type == 'HED': |
| self.detector = HEDdetector() |
| elif control_type == 'canny': |
| canny_detector = CannyDetector() |
| low_threshold = canny_low |
| high_threshold = canny_high |
|
|
| def apply_canny(x): |
| return canny_detector(x, low_threshold, high_threshold) |
|
|
| self.detector = apply_canny |
|
|
| elif control_type == 'depth': |
| midas = MidasDetector() |
|
|
| def apply_midas(x): |
| detected_map, _ = midas(x) |
| return detected_map |
|
|
| self.detector = apply_midas |
|
|
|
|
| global_state = GlobalState() |
| global_video_path = None |
| video_frame_count = None |
|
|
|
|
| def create_cfg(input_path, prompt, image_resolution, control_strength, |
| color_preserve, left_crop, right_crop, top_crop, bottom_crop, |
| control_type, low_threshold, high_threshold, ddim_steps, scale, |
| seed, sd_model, a_prompt, n_prompt, interval, keyframe_count, |
| x0_strength, use_constraints, cross_start, cross_end, |
| style_update_freq, warp_start, warp_end, mask_start, mask_end, |
| ada_start, ada_end, mask_strength, inner_strength, |
| smooth_boundary): |
| use_warp = 'shape-aware fusion' in use_constraints |
| use_mask = 'pixel-aware fusion' in use_constraints |
| use_ada = 'color-aware AdaIN' in use_constraints |
|
|
| if not use_warp: |
| warp_start = 1 |
| warp_end = 0 |
|
|
| if not use_mask: |
| mask_start = 1 |
| mask_end = 0 |
|
|
| if not use_ada: |
| ada_start = 1 |
| ada_end = 0 |
|
|
| input_name = os.path.split(input_path)[-1].split('.')[0] |
| frame_count = 2 + keyframe_count * interval |
| cfg = RerenderConfig() |
| cfg.create_from_parameters( |
| input_path, |
| os.path.join('result', input_name, 'blend.mp4'), |
| prompt, |
| a_prompt=a_prompt, |
| n_prompt=n_prompt, |
| frame_count=frame_count, |
| interval=interval, |
| crop=[left_crop, right_crop, top_crop, bottom_crop], |
| sd_model=sd_model, |
| ddim_steps=ddim_steps, |
| scale=scale, |
| control_type=control_type, |
| control_strength=control_strength, |
| canny_low=low_threshold, |
| canny_high=high_threshold, |
| seed=seed, |
| image_resolution=image_resolution, |
| x0_strength=x0_strength, |
| style_update_freq=style_update_freq, |
| cross_period=(cross_start, cross_end), |
| warp_period=(warp_start, warp_end), |
| mask_period=(mask_start, mask_end), |
| ada_period=(ada_start, ada_end), |
| mask_strength=mask_strength, |
| inner_strength=inner_strength, |
| smooth_boundary=smooth_boundary, |
| color_preserve=color_preserve) |
| return cfg |
|
|
|
|
| def cfg_to_input(filename): |
|
|
| cfg = RerenderConfig() |
| cfg.create_from_path(filename) |
| keyframe_count = (cfg.frame_count - 2) // cfg.interval |
| use_constraints = [ |
| 'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN' |
| ] |
|
|
| sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5') |
|
|
| args = [ |
| cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength, |
| cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low, |
| cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model, |
| cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count, |
| cfg.x0_strength, use_constraints, *cfg.cross_period, |
| cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period, |
| *cfg.ada_period, cfg.mask_strength, cfg.inner_strength, |
| cfg.smooth_boundary |
| ] |
| return args |
|
|
|
|
| def setup_color_correction(image): |
| correction_target = cv2.cvtColor(np.asarray(image.copy()), |
| cv2.COLOR_RGB2LAB) |
| return correction_target |
|
|
|
|
| def apply_color_correction(correction, original_image): |
| image = Image.fromarray( |
| cv2.cvtColor( |
| exposure.match_histograms(cv2.cvtColor(np.asarray(original_image), |
| cv2.COLOR_RGB2LAB), |
| correction, |
| channel_axis=2), |
| cv2.COLOR_LAB2RGB).astype('uint8')) |
|
|
| image = blendLayers(image, original_image, BlendType.LUMINOSITY) |
|
|
| return image |
|
|
|
|
| @torch.no_grad() |
| def process(*args): |
| first_frame = process1(*args) |
|
|
| keypath = process2(*args) |
|
|
| return first_frame, keypath |
|
|
|
|
| @torch.no_grad() |
| def process0(*args): |
| global global_video_path |
| global_video_path = args[0] |
| return process(*args[1:]) |
|
|
|
|
| @torch.no_grad() |
| def process1(*args): |
|
|
| global global_video_path |
| cfg = create_cfg(global_video_path, *args) |
| global global_state |
| global_state.update_sd_model(cfg.sd_model, cfg.control_type) |
| global_state.update_controller(cfg.inner_strength, cfg.mask_period, |
| cfg.cross_period, cfg.ada_period, |
| cfg.warp_period) |
| global_state.update_detector(cfg.control_type, cfg.canny_low, |
| cfg.canny_high) |
| global_state.processing_state = ProcessingState.FIRST_IMG |
|
|
| prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution, |
| cfg.crop) |
|
|
| ddim_v_sampler = global_state.ddim_v_sampler |
| model = ddim_v_sampler.model |
| detector = global_state.detector |
| controller = global_state.controller |
| model.control_scales = [cfg.control_strength] * 13 |
| model.to(device) |
|
|
| num_samples = 1 |
| eta = 0.0 |
| imgs = sorted(os.listdir(cfg.input_dir)) |
| imgs = [os.path.join(cfg.input_dir, img) for img in imgs] |
|
|
| model.cond_stage_model.device = device |
|
|
| with torch.no_grad(): |
| frame = cv2.imread(imgs[0]) |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
| img = HWC3(frame) |
| H, W, C = img.shape |
|
|
| img_ = numpy2tensor(img) |
|
|
| def generate_first_img(img_, strength): |
| encoder_posterior = model.encode_first_stage(img_.to(device)) |
| x0 = model.get_first_stage_encoding(encoder_posterior).detach() |
|
|
| detected_map = detector(img) |
| detected_map = HWC3(detected_map) |
|
|
| control = torch.from_numpy( |
| detected_map.copy()).float().to(device) / 255.0 |
| control = torch.stack([control for _ in range(num_samples)], dim=0) |
| control = einops.rearrange(control, 'b h w c -> b c h w').clone() |
| cond = { |
| 'c_concat': [control], |
| 'c_crossattn': [ |
| model.get_learned_conditioning( |
| [cfg.prompt + ', ' + cfg.a_prompt] * num_samples) |
| ] |
| } |
| un_cond = { |
| 'c_concat': [control], |
| 'c_crossattn': |
| [model.get_learned_conditioning([cfg.n_prompt] * num_samples)] |
| } |
| shape = (4, H // 8, W // 8) |
|
|
| controller.set_task('initfirst') |
| seed_everything(cfg.seed) |
|
|
| samples, _ = ddim_v_sampler.sample( |
| cfg.ddim_steps, |
| num_samples, |
| shape, |
| cond, |
| verbose=False, |
| eta=eta, |
| unconditional_guidance_scale=cfg.scale, |
| unconditional_conditioning=un_cond, |
| controller=controller, |
| x0=x0, |
| strength=strength) |
| x_samples = model.decode_first_stage(samples) |
| x_samples_np = ( |
| einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + |
| 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) |
| return x_samples, x_samples_np |
|
|
| |
| |
| if not cfg.color_preserve: |
| first_strength = -1 |
| else: |
| first_strength = 1 - cfg.x0_strength |
|
|
| x_samples, x_samples_np = generate_first_img(img_, first_strength) |
|
|
| if not cfg.color_preserve: |
| color_corrections = setup_color_correction( |
| Image.fromarray(x_samples_np[0])) |
| global_state.color_corrections = color_corrections |
| img_ = apply_color_correction(color_corrections, |
| Image.fromarray(img)) |
| img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1 |
| x_samples, x_samples_np = generate_first_img( |
| img_, 1 - cfg.x0_strength) |
|
|
| global_state.first_result = x_samples |
| global_state.first_img = img |
|
|
| Image.fromarray(x_samples_np[0]).save( |
| os.path.join(cfg.first_dir, 'first.jpg')) |
|
|
| return x_samples_np[0] |
|
|
|
|
| @torch.no_grad() |
| def process2(*args): |
| global global_state |
| global global_video_path |
|
|
| if global_state.processing_state != ProcessingState.FIRST_IMG: |
| raise gr.Error('Please generate the first key image before generating' |
| ' all key images') |
|
|
| cfg = create_cfg(global_video_path, *args) |
| global_state.update_sd_model(cfg.sd_model, cfg.control_type) |
| global_state.update_detector(cfg.control_type, cfg.canny_low, |
| cfg.canny_high) |
| global_state.processing_state = ProcessingState.KEY_IMGS |
|
|
| |
| shutil.rmtree(cfg.key_dir) |
| os.makedirs(cfg.key_dir, exist_ok=True) |
|
|
| ddim_v_sampler = global_state.ddim_v_sampler |
| model = ddim_v_sampler.model |
| detector = global_state.detector |
| controller = global_state.controller |
| flow_model = global_state.flow_model |
| model.control_scales = [cfg.control_strength] * 13 |
|
|
| num_samples = 1 |
| eta = 0.0 |
| firstx0 = True |
| pixelfusion = cfg.use_mask |
| imgs = sorted(os.listdir(cfg.input_dir)) |
| imgs = [os.path.join(cfg.input_dir, img) for img in imgs] |
|
|
| first_result = global_state.first_result |
| first_img = global_state.first_img |
| pre_result = first_result |
| pre_img = first_img |
|
|
| for i in range(0, cfg.frame_count - 1, cfg.interval): |
| cid = i + 1 |
| frame = cv2.imread(imgs[i + 1]) |
| print(cid) |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
| img = HWC3(frame) |
| H, W, C = img.shape |
|
|
| if cfg.color_preserve or global_state.color_corrections is None: |
| img_ = numpy2tensor(img) |
| else: |
| img_ = apply_color_correction(global_state.color_corrections, |
| Image.fromarray(img)) |
| img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1 |
| encoder_posterior = model.encode_first_stage(img_.to(device)) |
| x0 = model.get_first_stage_encoding(encoder_posterior).detach() |
|
|
| detected_map = detector(img) |
| detected_map = HWC3(detected_map) |
|
|
| control = torch.from_numpy( |
| detected_map.copy()).float().to(device) / 255.0 |
| control = torch.stack([control for _ in range(num_samples)], dim=0) |
| control = einops.rearrange(control, 'b h w c -> b c h w').clone() |
| cond = { |
| 'c_concat': [control], |
| 'c_crossattn': [ |
| model.get_learned_conditioning( |
| [cfg.prompt + ', ' + cfg.a_prompt] * num_samples) |
| ] |
| } |
| un_cond = { |
| 'c_concat': [control], |
| 'c_crossattn': |
| [model.get_learned_conditioning([cfg.n_prompt] * num_samples)] |
| } |
| shape = (4, H // 8, W // 8) |
|
|
| cond['c_concat'] = [control] |
| un_cond['c_concat'] = [control] |
|
|
| image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float() |
| image2 = torch.from_numpy(img).permute(2, 0, 1).float() |
| warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask( |
| flow_model, image1, image2, pre_result, False) |
| blend_mask_pre = blur( |
| F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4)) |
| blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1) |
|
|
| image1 = torch.from_numpy(first_img).permute(2, 0, 1).float() |
| warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask( |
| flow_model, image1, image2, first_result, False) |
| blend_mask_0 = blur( |
| F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4)) |
| blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1) |
|
|
| if firstx0: |
| mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8) |
| controller.set_warp( |
| F.interpolate(bwd_flow_0 / 8.0, |
| scale_factor=1. / 8, |
| mode='bilinear'), mask) |
| else: |
| mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8) |
| controller.set_warp( |
| F.interpolate(bwd_flow_pre / 8.0, |
| scale_factor=1. / 8, |
| mode='bilinear'), mask) |
|
|
| controller.set_task('keepx0, keepstyle') |
| seed_everything(cfg.seed) |
| samples, intermediates = ddim_v_sampler.sample( |
| cfg.ddim_steps, |
| num_samples, |
| shape, |
| cond, |
| verbose=False, |
| eta=eta, |
| unconditional_guidance_scale=cfg.scale, |
| unconditional_conditioning=un_cond, |
| controller=controller, |
| x0=x0, |
| strength=1 - cfg.x0_strength) |
| direct_result = model.decode_first_stage(samples) |
|
|
| if not pixelfusion: |
| pre_result = direct_result |
| pre_img = img |
| viz = ( |
| einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 + |
| 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) |
|
|
| else: |
|
|
| blend_results = (1 - blend_mask_pre |
| ) * warped_pre + blend_mask_pre * direct_result |
| blend_results = ( |
| 1 - blend_mask_0) * warped_0 + blend_mask_0 * blend_results |
|
|
| bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1) |
| blend_mask = blur( |
| F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4)) |
| blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1) |
|
|
| encoder_posterior = model.encode_first_stage(blend_results) |
| xtrg = model.get_first_stage_encoding( |
| encoder_posterior).detach() |
| blend_results_rec = model.decode_first_stage(xtrg) |
| encoder_posterior = model.encode_first_stage(blend_results_rec) |
| xtrg_rec = model.get_first_stage_encoding( |
| encoder_posterior).detach() |
| xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec)) |
| blend_results_rec_new = model.decode_first_stage(xtrg_) |
| tmp = (abs(blend_results_rec_new - blend_results).mean( |
| dim=1, keepdims=True) > 0.25).float() |
| mask_x = F.max_pool2d((F.interpolate(tmp, |
| scale_factor=1 / 8., |
| mode='bilinear') > 0).float(), |
| kernel_size=3, |
| stride=1, |
| padding=1) |
|
|
| mask = (1 - F.max_pool2d(1 - blend_mask, kernel_size=8) |
| ) |
|
|
| if cfg.smooth_boundary: |
| noise_rescale = find_flat_region(mask) |
| else: |
| noise_rescale = torch.ones_like(mask) |
| masks = [] |
| for i in range(cfg.ddim_steps): |
| if i <= cfg.ddim_steps * cfg.mask_period[ |
| 0] or i >= cfg.ddim_steps * cfg.mask_period[1]: |
| masks += [None] |
| else: |
| masks += [mask * cfg.mask_strength] |
|
|
| |
| |
| |
| |
| |
| xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask |
|
|
| tasks = 'keepstyle, keepx0' |
| if not firstx0: |
| tasks += ', updatex0' |
| if i % cfg.style_update_freq == 0: |
| tasks += ', updatestyle' |
| controller.set_task(tasks, 1.0) |
|
|
| seed_everything(cfg.seed) |
| samples, _ = ddim_v_sampler.sample( |
| cfg.ddim_steps, |
| num_samples, |
| shape, |
| cond, |
| verbose=False, |
| eta=eta, |
| unconditional_guidance_scale=cfg.scale, |
| unconditional_conditioning=un_cond, |
| controller=controller, |
| x0=x0, |
| strength=1 - cfg.x0_strength, |
| xtrg=xtrg, |
| mask=masks, |
| noise_rescale=noise_rescale) |
| x_samples = model.decode_first_stage(samples) |
| pre_result = x_samples |
| pre_img = img |
|
|
| viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + |
| 127.5).cpu().numpy().clip(0, 255).astype(np.uint8) |
|
|
| Image.fromarray(viz[0]).save( |
| os.path.join(cfg.key_dir, f'{cid:04d}.png')) |
|
|
| key_video_path = os.path.join(cfg.work_dir, 'key.mp4') |
| fps = get_fps(cfg.input_path) |
| fps //= cfg.interval |
| frame_to_video(key_video_path, cfg.key_dir, fps, False) |
|
|
| return key_video_path |
|
|
|
|
| DESCRIPTION = ''' |
| ## [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) |
| ### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper. |
| ### To avoid overload, we set limitations to the **maximum frame number** (8) and the maximum frame resolution (512x768). |
| ### The running time of a video of size 512x640 is about 1 minute per keyframe under T4 GPU. |
| ### How to use: |
| 1. **Run 1st Key Frame**: only translate the first frame, so you can adjust the prompts/models/parameters to find your ideal output appearance before run the whole video. |
| 2. **Run Key Frames**: translate all the key frames based on the settings of the first frame |
| 3. **Run All**: **Run 1st Key Frame** and **Run Key Frames** |
| 4. **Run Propagation**: propogate the key frames to other frames for full video translation. This function is supported [here](https://github.com/williamyang1991/Rerender_A_Video#webui-recommended) |
| ### Tips: |
| 1. This method cannot handle large or quick motions where the optical flow is hard to estimate. **Videos with stable motions are preferred**. |
| 2. Pixel-aware fusion may not work for large or quick motions. |
| 3. Try different color-aware AdaIN settings and even unuse it to avoid color jittering. |
| 4. `revAnimated_v11` model for non-photorealstic style, `realisticVisionV20_v20` model for photorealstic style. |
| 5. To use your own SD/LoRA model, you may clone the space and specify your model with [sd_model_cfg.py](https://huggingface.co/spaces/Anonymous-sub/Rerender/blob/main/sd_model_cfg.py). |
| 6. This method is based on the original SD model. You may need to [convert](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py) Diffuser/Automatic1111 models to the original one. |
| |
| **This code is for research purpose and non-commercial use only.** |
| |
| [](https://huggingface.co/spaces/Anonymous-sub/Rerender?duplicate=true) for no queue on your own hardware. |
| ''' |
|
|
|
|
| ARTICLE = r""" |
| If Rerender-A-Video is helpful, please help to ⭐ the <a href='https://github.com/williamyang1991/Rerender_A_Video' target='_blank'>Github Repo</a>. Thanks! |
| [](https://github.com/williamyang1991/Rerender_A_Video) |
| --- |
| 📝 **Citation** |
| If our work is useful for your research, please consider citing: |
| ```bibtex |
| @inproceedings{yang2023rerender, |
| title = {Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation}, |
| author = {Yang, Shuai and Zhou, Yifan and Liu, Ziwei and and Loy, Chen Change}, |
| booktitle = {ACM SIGGRAPH Asia Conference Proceedings}, |
| year = {2023}, |
| } |
| ``` |
| 📋 **License** |
| This project is licensed under <a rel="license" href="https://github.com/williamyang1991/Rerender_A_Video/blob/main/LICENSE.md">S-Lab License 1.0</a>. |
| Redistribution and use for non-commercial purposes should follow this license. |
| |
| 📧 **Contact** |
| If you have any questions, please feel free to reach me out at <b>williamyang@pku.edu.cn</b>. |
| """ |
|
|
| FOOTER = '<div align=center><img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.laobi.icu/badge?page_id=williamyang1991/Rerender_A_Video" /></div>' |
|
|
|
|
| block = gr.Blocks().queue() |
| with block: |
| with gr.Row(): |
| gr.Markdown(DESCRIPTION) |
| with gr.Row(): |
| with gr.Column(): |
| input_path = gr.Video(label='Input Video', |
| source='upload', |
| format='mp4', |
| visible=True) |
| prompt = gr.Textbox(label='Prompt') |
| seed = gr.Slider(label='Seed', |
| minimum=0, |
| maximum=2147483647, |
| step=1, |
| value=0, |
| randomize=True) |
| run_button = gr.Button(value='Run All') |
| with gr.Row(): |
| run_button1 = gr.Button(value='Run 1st Key Frame') |
| run_button2 = gr.Button(value='Run Key Frames') |
| run_button3 = gr.Button(value='Run Propagation') |
| with gr.Accordion('Advanced options for the 1st frame translation', |
| open=False): |
| image_resolution = gr.Slider( |
| label='Frame rsolution', |
| minimum=256, |
| maximum=512, |
| value=512, |
| step=64, |
| info='To avoid overload, maximum 512') |
| control_strength = gr.Slider(label='ControNet strength', |
| minimum=0.0, |
| maximum=2.0, |
| value=1.0, |
| step=0.01) |
| x0_strength = gr.Slider( |
| label='Denoising strength', |
| minimum=0.00, |
| maximum=1.05, |
| value=0.75, |
| step=0.05, |
| info=('0: fully recover the input.' |
| '1.05: fully rerender the input.')) |
| color_preserve = gr.Checkbox( |
| label='Preserve color', |
| value=True, |
| info='Keep the color of the input video') |
| with gr.Row(): |
| left_crop = gr.Slider(label='Left crop length', |
| minimum=0, |
| maximum=512, |
| value=0, |
| step=1) |
| right_crop = gr.Slider(label='Right crop length', |
| minimum=0, |
| maximum=512, |
| value=0, |
| step=1) |
| with gr.Row(): |
| top_crop = gr.Slider(label='Top crop length', |
| minimum=0, |
| maximum=512, |
| value=0, |
| step=1) |
| bottom_crop = gr.Slider(label='Bottom crop length', |
| minimum=0, |
| maximum=512, |
| value=0, |
| step=1) |
| with gr.Row(): |
| control_type = gr.Dropdown(['HED', 'canny', 'depth'], |
| label='Control type', |
| value='HED') |
| low_threshold = gr.Slider(label='Canny low threshold', |
| minimum=1, |
| maximum=255, |
| value=100, |
| step=1) |
| high_threshold = gr.Slider(label='Canny high threshold', |
| minimum=1, |
| maximum=255, |
| value=200, |
| step=1) |
| ddim_steps = gr.Slider(label='Steps', |
| minimum=1, |
| maximum=20, |
| value=20, |
| step=1, |
| info='To avoid overload, maximum 20') |
| scale = gr.Slider(label='CFG scale', |
| minimum=0.1, |
| maximum=30.0, |
| value=7.5, |
| step=0.1) |
| sd_model_list = list(model_dict.keys()) |
| sd_model = gr.Dropdown(sd_model_list, |
| label='Base model', |
| value='Stable Diffusion 1.5') |
| a_prompt = gr.Textbox(label='Added prompt', |
| value='best quality, extremely detailed') |
| n_prompt = gr.Textbox( |
| label='Negative prompt', |
| value=('longbody, lowres, bad anatomy, bad hands, ' |
| 'missing fingers, extra digit, fewer digits, ' |
| 'cropped, worst quality, low quality')) |
| with gr.Accordion('Advanced options for the key fame translation', |
| open=False): |
| interval = gr.Slider( |
| label='Key frame frequency (K)', |
| minimum=1, |
| maximum=MAX_KEYFRAME, |
| value=1, |
| step=1, |
| info='Uniformly sample the key frames every K frames') |
| keyframe_count = gr.Slider( |
| label='Number of key frames', |
| minimum=1, |
| maximum=MAX_KEYFRAME, |
| value=1, |
| step=1, |
| info='To avoid overload, maximum 8 key frames') |
|
|
| use_constraints = gr.CheckboxGroup( |
| [ |
| 'shape-aware fusion', 'pixel-aware fusion', |
| 'color-aware AdaIN' |
| ], |
| label='Select the cross-frame contraints to be used', |
| value=[ |
| 'shape-aware fusion', 'pixel-aware fusion', |
| 'color-aware AdaIN' |
| ]), |
| with gr.Row(): |
| cross_start = gr.Slider( |
| label='Cross-frame attention start', |
| minimum=0, |
| maximum=1, |
| value=0, |
| step=0.05) |
| cross_end = gr.Slider(label='Cross-frame attention end', |
| minimum=0, |
| maximum=1, |
| value=1, |
| step=0.05) |
| style_update_freq = gr.Slider( |
| label='Cross-frame attention update frequency', |
| minimum=1, |
| maximum=100, |
| value=1, |
| step=1, |
| info=('Update the key and value for ' |
| 'cross-frame attention every N key frames (recommend N*K>=10)' |
| )) |
| with gr.Row(): |
| warp_start = gr.Slider(label='Shape-aware fusion start', |
| minimum=0, |
| maximum=1, |
| value=0, |
| step=0.05) |
| warp_end = gr.Slider(label='Shape-aware fusion end', |
| minimum=0, |
| maximum=1, |
| value=0.1, |
| step=0.05) |
| with gr.Row(): |
| mask_start = gr.Slider(label='Pixel-aware fusion start', |
| minimum=0, |
| maximum=1, |
| value=0.5, |
| step=0.05) |
| mask_end = gr.Slider(label='Pixel-aware fusion end', |
| minimum=0, |
| maximum=1, |
| value=0.8, |
| step=0.05) |
| with gr.Row(): |
| ada_start = gr.Slider(label='Color-aware AdaIN start', |
| minimum=0, |
| maximum=1, |
| value=0.8, |
| step=0.05) |
| ada_end = gr.Slider(label='Color-aware AdaIN end', |
| minimum=0, |
| maximum=1, |
| value=1, |
| step=0.05) |
| mask_strength = gr.Slider(label='Pixel-aware fusion stength', |
| minimum=0, |
| maximum=1, |
| value=0.5, |
| step=0.01) |
| inner_strength = gr.Slider( |
| label='Pixel-aware fusion detail level', |
| minimum=0.5, |
| maximum=1, |
| value=0.9, |
| step=0.01, |
| info='Use a low value to prevent artifacts') |
| smooth_boundary = gr.Checkbox( |
| label='Smooth fusion boundary', |
| value=True, |
| info='Select to prevent artifacts at boundary') |
|
|
| with gr.Accordion('Example configs', open=True): |
| config_dir = 'config' |
| config_list = os.listdir(config_dir) |
| args_list = [] |
| for config in config_list: |
| try: |
| config_path = os.path.join(config_dir, config) |
| args = cfg_to_input(config_path) |
| args_list.append(args) |
| except FileNotFoundError: |
| |
| pass |
|
|
| ips = [ |
| prompt, image_resolution, control_strength, color_preserve, |
| left_crop, right_crop, top_crop, bottom_crop, control_type, |
| low_threshold, high_threshold, ddim_steps, scale, seed, |
| sd_model, a_prompt, n_prompt, interval, keyframe_count, |
| x0_strength, use_constraints[0], cross_start, cross_end, |
| style_update_freq, warp_start, warp_end, mask_start, |
| mask_end, ada_start, ada_end, mask_strength, |
| inner_strength, smooth_boundary |
| ] |
|
|
| with gr.Column(): |
| result_image = gr.Image(label='Output first frame', |
| type='numpy', |
| interactive=False) |
| result_keyframe = gr.Video(label='Output key frame video', |
| format='mp4', |
| interactive=False) |
| with gr.Row(): |
| gr.Examples(examples=args_list, |
| inputs=[input_path, *ips], |
| fn=process0, |
| outputs=[result_image, result_keyframe], |
| cache_examples=True) |
|
|
| gr.Markdown(ARTICLE) |
| gr.Markdown(FOOTER) |
|
|
| def input_uploaded(path): |
| frame_count = get_frame_count(path) |
| if frame_count <= 2: |
| raise gr.Error('The input video is too short!' |
| 'Please input another video.') |
|
|
| default_interval = min(10, frame_count - 2) |
| max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME) |
|
|
| global video_frame_count |
| video_frame_count = frame_count |
| global global_video_path |
| global_video_path = path |
|
|
| return gr.Slider.update(value=default_interval, |
| maximum=frame_count - 2), gr.Slider.update( |
| value=max_keyframe, maximum=max_keyframe) |
|
|
| def input_changed(path): |
| frame_count = get_frame_count(path) |
| if frame_count <= 2: |
| return gr.Slider.update(maximum=1), gr.Slider.update(maximum=1) |
|
|
| default_interval = min(10, frame_count - 2) |
| max_keyframe = min((frame_count - 2) // default_interval, MAX_KEYFRAME) |
|
|
| global video_frame_count |
| video_frame_count = frame_count |
| global global_video_path |
| global_video_path = path |
|
|
| return gr.Slider.update(value=default_interval, |
| maximum=frame_count - 2), \ |
| gr.Slider.update(maximum=max_keyframe) |
|
|
| def interval_changed(interval): |
| global video_frame_count |
| if video_frame_count is None: |
| return gr.Slider.update() |
|
|
| max_keyframe = min((video_frame_count - 2) // interval, MAX_KEYFRAME) |
|
|
| return gr.Slider.update(value=max_keyframe, maximum=max_keyframe) |
|
|
| input_path.change(input_changed, input_path, [interval, keyframe_count]) |
| input_path.upload(input_uploaded, input_path, [interval, keyframe_count]) |
| interval.change(interval_changed, interval, keyframe_count) |
|
|
| run_button.click(fn=process, |
| inputs=ips, |
| outputs=[result_image, result_keyframe]) |
| run_button1.click(fn=process1, inputs=ips, outputs=[result_image]) |
| run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe]) |
|
|
| def process3(): |
| raise gr.Error( |
| "Coming Soon. Full code for full video translation will be " |
| "released upon the publication of the paper.") |
|
|
| run_button3.click(fn=process3, outputs=[result_keyframe]) |
|
|
| block.queue(concurrency_count=1, max_size=20) |
| block.launch(server_name='0.0.0.0') |
|
|