import os
import shutil
from enum import Enum

import cv2
import einops
import gradio as gr
import huggingface_hub
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from blendmodes.blend import BlendType, blendLayers
from PIL import Image
from pytorch_lightning import seed_everything
from safetensors.torch import load_file
from skimage import exposure

import src.import_util  # noqa: F401
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.hed import HEDdetector
from ControlNet.annotator.util import HWC3
from ControlNet.cldm.model import create_model, load_state_dict
from gmflow_module.gmflow.gmflow import GMFlow
from flow.flow_utils import get_warped_and_mask
from src.config import RerenderConfig
from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor
from src.video_util import (frame_to_video, get_fps, get_frame_count,
                            prepare_frames)
repo_name = 'Anonymous-sub/Rerender'

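# Fetch the demo videos for the Gradio examples from the Hugging Face repo.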
huggingface_hub.hf_hub_download(repo_name,
                                'pexels-koolshooters-7322716.mp4',
                                local_dir='videos')
huggingface_hub.hf_hub_download(
    repo_name,
    'pexels-antoni-shkraba-8048492-540x960-25fps.mp4',
    local_dir='videos')
huggingface_hub.hf_hub_download(
    repo_name,
    'pexels-cottonbro-studio-6649832-960x506-25fps.mp4',
    local_dir='videos')

# NOTE: `model_dict` (UI model name -> checkpoint path inside `repo_name`) is
# referenced below but was never defined in this file. The single entry here
# is an assumed placeholder; extend it with the checkpoints actually hosted in
# the repo. An empty path keeps the base ControlNet weights.
model_dict = {'Stable Diffusion 1.5': ''}

inversed_model_dict = {v: k for k, v in model_dict.items()}

to_tensor = T.PILToTensor()
blur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class ProcessingState(Enum):
    NULL = 0
    FIRST_IMG = 1
    KEY_IMGS = 2


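# Holds the lazily-loaded models and cached intermediate results shared across
# Gradio callbacks, so repeated runs do not reload checkpoints.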
class GlobalState:

    def __init__(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL
        # Initialized here so process2 can test them before process1 has run.
        self.first_result = None
        self.first_img = None
        self.color_corrections = None
        # GMFlow estimates the optical flow used to warp results across
        # frames.
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to(device)

        checkpoint = torch.load('models/gmflow_sintel-0c07dcb3.pth',
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.flow_model = flow_model

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

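    # Reload the ControlNet-wrapped SD model when the user switches base
    # checkpoints; the fine-tuned MSE VAE is swapped in when available.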
    def update_sd_model(self, sd_model, control_type):
        if sd_model == self.sd_model:
            return
        self.sd_model = sd_model
        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        if control_type == 'HED':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', 'models/control_sd15_hed.pth'),
                                location=device))
        elif control_type == 'canny':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', 'models/control_sd15_canny.pth'),
                                location=device))
        model.to(device)
        sd_model_path = model_dict[sd_model]
        if len(sd_model_path) > 0:
            model_ext = os.path.splitext(sd_model_path)[1]
            downloaded_model = huggingface_hub.hf_hub_download(
                repo_name, sd_model_path)
            if model_ext == '.safetensors':
                model.load_state_dict(load_file(downloaded_model),
                                      strict=False)
            elif model_ext == '.ckpt' or model_ext == '.pth':
                model.load_state_dict(
                    torch.load(downloaded_model)['state_dict'], strict=False)

        try:
            model.first_stage_model.load_state_dict(torch.load(
                huggingface_hub.hf_hub_download(
                    'stabilityai/sd-vae-ft-mse-original',
                    'vae-ft-mse-840000-ema-pruned.ckpt'))['state_dict'],
                strict=False)
        except Exception:
            print('Warning: failed to load the fine-tuned VAE; '
                  'generation quality may be degraded')

        self.ddim_v_sampler = DDIMVSampler(model)

    def clear_sd_model(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        # Cache on the full configuration: the original compared only
        # `control_type` but never stored it, so the check could never hit and
        # the detector was rebuilt on every call.
        detector_key = (control_type, canny_low, canny_high)
        if self.detector_type == detector_key:
            return
        self.detector_type = detector_key
        if control_type == 'HED':
            self.detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()
            low_threshold = canny_low
            high_threshold = canny_high

            def apply_canny(x):
                return canny_detector(x, low_threshold, high_threshold)

            self.detector = apply_canny


global_state = GlobalState()
global_video_path = None
video_frame_count = None


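# Bundle the raw Gradio widget values into a RerenderConfig.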
def create_cfg(input_path, prompt, image_resolution, control_strength,
               color_preserve, left_crop, right_crop, top_crop, bottom_crop,
               control_type, low_threshold, high_threshold, ddim_steps, scale,
               seed, sd_model, a_prompt, n_prompt, interval, keyframe_count,
               x0_strength, use_constraints, cross_start, cross_end,
               style_update_freq, warp_start, warp_end, mask_start, mask_end,
               ada_start, ada_end, mask_strength, inner_strength,
               smooth_boundary):
    use_warp = 'shape-aware fusion' in use_constraints
    use_mask = 'pixel-aware fusion' in use_constraints
    use_ada = 'color-aware AdaIN' in use_constraints

    # An empty period (start > end) disables the corresponding constraint.
    if not use_warp:
        warp_start = 1
        warp_end = 0

    if not use_mask:
        mask_start = 1
        mask_end = 0

    if not use_ada:
        ada_start = 1
        ada_end = 0

    input_name = os.path.split(input_path)[-1].split('.')[0]
    frame_count = 2 + keyframe_count * interval
    cfg = RerenderConfig()
    cfg.create_from_parameters(
        input_path,
        os.path.join('result', input_name, 'blend.mp4'),
        prompt,
        a_prompt=a_prompt,
        n_prompt=n_prompt,
        frame_count=frame_count,
        interval=interval,
        crop=[left_crop, right_crop, top_crop, bottom_crop],
        sd_model=sd_model,
        ddim_steps=ddim_steps,
        scale=scale,
        control_type=control_type,
        control_strength=control_strength,
        canny_low=low_threshold,
        canny_high=high_threshold,
        seed=seed,
        image_resolution=image_resolution,
        x0_strength=x0_strength,
        style_update_freq=style_update_freq,
        cross_period=(cross_start, cross_end),
        warp_period=(warp_start, warp_end),
        mask_period=(mask_start, mask_end),
        ada_period=(ada_start, ada_end),
        mask_strength=mask_strength,
        inner_strength=inner_strength,
        smooth_boundary=smooth_boundary,
        color_preserve=color_preserve)
    return cfg


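# Inverse of create_cfg: unpack a saved config file into the flat argument
# list expected by the Gradio widgets.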
def cfg_to_input(filename):
    cfg = RerenderConfig()
    cfg.create_from_path(filename)
    keyframe_count = (cfg.frame_count - 2) // cfg.interval
    use_constraints = [
        'shape-aware fusion', 'pixel-aware fusion', 'color-aware AdaIN'
    ]

    sd_model = inversed_model_dict.get(cfg.sd_model, 'Stable Diffusion 1.5')

    args = [
        cfg.input_path, cfg.prompt, cfg.image_resolution, cfg.control_strength,
        cfg.color_preserve, *cfg.crop, cfg.control_type, cfg.canny_low,
        cfg.canny_high, cfg.ddim_steps, cfg.scale, cfg.seed, sd_model,
        cfg.a_prompt, cfg.n_prompt, cfg.interval, keyframe_count,
        cfg.x0_strength, use_constraints, *cfg.cross_period,
        cfg.style_update_freq, *cfg.warp_period, *cfg.mask_period,
        *cfg.ada_period, cfg.mask_strength, cfg.inner_strength,
        cfg.smooth_boundary
    ]
    return args


def setup_color_correction(image):
    correction_target = cv2.cvtColor(np.asarray(image.copy()),
                                     cv2.COLOR_RGB2LAB)
    return correction_target


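# Match the colors of original_image to the stored LAB reference histogram,
# then restore the original luminosity via a luminosity blend.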
def apply_color_correction(correction, original_image):
    image = Image.fromarray(
        cv2.cvtColor(
            exposure.match_histograms(cv2.cvtColor(np.asarray(original_image),
                                                   cv2.COLOR_RGB2LAB),
                                      correction,
                                      channel_axis=2),
            cv2.COLOR_LAB2RGB).astype('uint8'))

    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

    return image


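# The pipeline runs in two stages: process1 stylizes and caches the first
# frame; process2 propagates that style to the remaining key frames.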
@torch.no_grad()
def process(*args):
    first_frame = process1(*args)
    keypath = process2(*args)
    return first_frame, keypath


@torch.no_grad()
def process0(*args):
    global global_video_path
    global_video_path = args[0]
    return process(*args[1:])


@torch.no_grad()
def process1(*args):
    global global_video_path
    global global_state
    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_controller(cfg.inner_strength, cfg.mask_period,
                                   cfg.cross_period, cfg.ada_period,
                                   cfg.warp_period)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.FIRST_IMG

    # Extract, resize and crop the input video frames into cfg.input_dir.
    prepare_frames(cfg.input_path, cfg.input_dir, cfg.image_resolution,
                   cfg.crop)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    model.control_scales = [cfg.control_strength] * 13
    model.to(device)

    num_samples = 1
    eta = 0.0
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    model.cond_stage_model.device = device

    with torch.no_grad():
        frame = cv2.imread(imgs[0])
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, C = img.shape

        img_ = numpy2tensor(img)

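        # Encode the frame to a latent x0, build ControlNet conditioning from
        # the detected edge map, and run DDIM sampling constrained by x0.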
        def generate_first_img(img_, strength):
            encoder_posterior = model.encode_first_stage(img_.to(device))
            x0 = model.get_first_stage_encoding(encoder_posterior).detach()

            detected_map = detector(img)
            detected_map = HWC3(detected_map)

            control = torch.from_numpy(
                detected_map.copy()).float().to(device) / 255.0
            control = torch.stack([control for _ in range(num_samples)],
                                  dim=0)
            control = einops.rearrange(control, 'b h w c -> b c h w').clone()
            cond = {
                'c_concat': [control],
                'c_crossattn': [
                    model.get_learned_conditioning(
                        [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
                ]
            }
            un_cond = {
                'c_concat': [control],
                'c_crossattn':
                [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
            }
            shape = (4, H // 8, W // 8)

            controller.set_task('initfirst')
            seed_everything(cfg.seed)

            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=strength)
            x_samples = model.decode_first_stage(samples)
            x_samples_np = (
                einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            return x_samples, x_samples_np

        # Without color preservation, generate freely first (strength -1),
        # then regenerate below from a color-corrected input.
        if not cfg.color_preserve:
            first_strength = -1
        else:
            first_strength = 1 - cfg.x0_strength

        x_samples, x_samples_np = generate_first_img(img_, first_strength)

        if not cfg.color_preserve:
            color_corrections = setup_color_correction(
                Image.fromarray(x_samples_np[0]))
            global_state.color_corrections = color_corrections
            img_ = apply_color_correction(color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
            x_samples, x_samples_np = generate_first_img(
                img_, 1 - cfg.x0_strength)

        global_state.first_result = x_samples
        global_state.first_img = img

    Image.fromarray(x_samples_np[0]).save(
        os.path.join(cfg.first_dir, 'first.jpg'))

    return x_samples_np[0]


@torch.no_grad()
def process2(*args):
    global global_state
    global global_video_path

    if global_state.processing_state != ProcessingState.FIRST_IMG:
        raise gr.Error('Please generate the first key image before generating'
                       ' all key images')

    cfg = create_cfg(global_video_path, *args)
    global_state.update_sd_model(cfg.sd_model, cfg.control_type)
    global_state.update_detector(cfg.control_type, cfg.canny_low,
                                 cfg.canny_high)
    global_state.processing_state = ProcessingState.KEY_IMGS

    # Start from a clean key-frame directory.
    shutil.rmtree(cfg.key_dir, ignore_errors=True)
    os.makedirs(cfg.key_dir, exist_ok=True)

    ddim_v_sampler = global_state.ddim_v_sampler
    model = ddim_v_sampler.model
    detector = global_state.detector
    controller = global_state.controller
    flow_model = global_state.flow_model
    model.control_scales = [cfg.control_strength] * 13

    num_samples = 1
    eta = 0.0
    firstx0 = True
    pixelfusion = cfg.use_mask
    imgs = sorted(os.listdir(cfg.input_dir))
    imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

    first_result = global_state.first_result
    first_img = global_state.first_img
    pre_result = first_result
    pre_img = first_img

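    # Render one key frame every cfg.interval input frames, warping both the
    # previous result and the first result into the current frame via optical
    # flow for cross-frame consistency.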
    for i in range(0, cfg.frame_count - 1, cfg.interval):
        cid = i + 1
        print(cid)
        frame = cv2.imread(imgs[i + 1])
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = HWC3(frame)
        H, W, C = img.shape

        if cfg.color_preserve or global_state.color_corrections is None:
            img_ = numpy2tensor(img)
        else:
            img_ = apply_color_correction(global_state.color_corrections,
                                          Image.fromarray(img))
            img_ = to_tensor(img_).unsqueeze(0)[:, :3] / 127.5 - 1
        encoder_posterior = model.encode_first_stage(img_.to(device))
        x0 = model.get_first_stage_encoding(encoder_posterior).detach()

        detected_map = detector(img)
        detected_map = HWC3(detected_map)

        control = torch.from_numpy(
            detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        cond = {
            'c_concat': [control],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }
        shape = (4, H // 8, W // 8)

        # Warp the previous stylized result and the first stylized result into
        # the current frame; bwd_occ_* mark occluded (unreliable) regions.
        image1 = torch.from_numpy(pre_img).permute(2, 0, 1).float()
        image2 = torch.from_numpy(img).permute(2, 0, 1).float()
        warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
            flow_model, image1, image2, pre_result, False)
        blend_mask_pre = blur(
            F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
        blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)

        image1 = torch.from_numpy(first_img).permute(2, 0, 1).float()
        warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
            flow_model, image1, image2, first_result, False)
        blend_mask_0 = blur(
            F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
        blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)

        # Downsample the flow and mask by 8 to match the latent resolution.
        if firstx0:
            mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_0 / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)
        else:
            mask = 1 - F.max_pool2d(blend_mask_pre, kernel_size=8)
            controller.set_warp(
                F.interpolate(bwd_flow_pre / 8.0,
                              scale_factor=1. / 8,
                              mode='bilinear'), mask)

        controller.set_task('keepx0, keepstyle')
        seed_everything(cfg.seed)
        samples, intermediates = ddim_v_sampler.sample(
            cfg.ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=cfg.scale,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=1 - cfg.x0_strength)
        direct_result = model.decode_first_stage(samples)

        if not pixelfusion:
            pre_result = direct_result
            pre_img = img
            viz = (
                einops.rearrange(direct_result, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        else:
            # Pixel-aware fusion: composite the warped previous/first results
            # over the raw generation wherever the flow is reliable.
            blend_results = ((1 - blend_mask_pre) * warped_pre +
                             blend_mask_pre * direct_result)
            blend_results = ((1 - blend_mask_0) * warped_0 +
                             blend_mask_0 * blend_results)

            bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)
            blend_mask = blur(
                F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))
            blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)

            # Fidelity-oriented latent: add back the VAE encode-decode
            # residual, except where the reconstruction error is large
            # (mask_x).
            encoder_posterior = model.encode_first_stage(blend_results)
            xtrg = model.get_first_stage_encoding(encoder_posterior).detach()
            blend_results_rec = model.decode_first_stage(xtrg)
            encoder_posterior = model.encode_first_stage(blend_results_rec)
            xtrg_rec = model.get_first_stage_encoding(
                encoder_posterior).detach()
            xtrg_ = (xtrg + 1 * (xtrg - xtrg_rec))
            blend_results_rec_new = model.decode_first_stage(xtrg_)
            tmp = (abs(blend_results_rec_new - blend_results).mean(
                dim=1, keepdim=True) > 0.25).float()
            mask_x = F.max_pool2d((F.interpolate(tmp,
                                                 scale_factor=1 / 8.,
                                                 mode='bilinear') >
                                   0).float(),
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

            mask = 1 - F.max_pool2d(1 - blend_mask, kernel_size=8)

            if cfg.smooth_boundary:
                noise_rescale = find_flat_region(mask)
            else:
                noise_rescale = torch.ones_like(mask)

            # Only apply the fusion mask inside the configured span of the
            # DDIM schedule. The loop variable is renamed from `i` (as in the
            # original) to `step` so it no longer clobbers the outer frame
            # index, which is still used below.
            masks = []
            for step in range(cfg.ddim_steps):
                if step <= cfg.ddim_steps * cfg.mask_period[
                        0] or step >= cfg.ddim_steps * cfg.mask_period[1]:
                    masks += [None]
                else:
                    masks += [mask * cfg.mask_strength]

            xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask

            tasks = 'keepstyle, keepx0'
            if not firstx0:
                tasks += ', updatex0'
            if i % cfg.style_update_freq == 0:
                tasks += ', updatestyle'
            controller.set_task(tasks, 1.0)

            seed_everything(cfg.seed)
            samples, _ = ddim_v_sampler.sample(
                cfg.ddim_steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=cfg.scale,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=1 - cfg.x0_strength,
                xtrg=xtrg,
                mask=masks,
                noise_rescale=noise_rescale)
            x_samples = model.decode_first_stage(samples)
            pre_result = x_samples
            pre_img = img

            viz = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                   127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        Image.fromarray(viz[0]).save(
            os.path.join(cfg.key_dir, f'{cid:04d}.png'))

    key_video_path = os.path.join(cfg.work_dir, 'key.mp4')
    fps = get_fps(cfg.input_path)
    fps //= cfg.interval
    frame_to_video(key_video_path, cfg.key_dir, fps, False)

    return key_video_path