Spaces:
Runtime error
Runtime error
| import os | |
| #os.environ['CUDA_VISIBLE_DEVICES'] = "6" | |
| # In China, set this to use huggingface | |
| # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' | |
| import cv2 | |
| import io | |
| import gc | |
| import yaml | |
| import argparse | |
| import torch | |
| import torchvision | |
| import diffusers | |
| from diffusers import StableDiffusionPipeline, AutoencoderKL, DDPMScheduler, ControlNetModel | |
| from src.utils import * | |
| from src.keyframe_selection import get_keyframe_ind | |
| from src.diffusion_hacked import apply_FRESCO_attn, apply_FRESCO_opt, disable_FRESCO_opt | |
| from src.diffusion_hacked import get_flow_and_interframe_paras, get_intraframe_paras | |
| from src.pipe_FRESCO import inference | |
| def get_models(config): | |
| print('\n' + '=' * 100) | |
| print('creating models...') | |
| import sys | |
| sys.path.append("./src/ebsynth/deps/gmflow/") | |
| sys.path.append("./src/EGNet/") | |
| sys.path.append("./src/ControlNet/") | |
| from gmflow.gmflow import GMFlow | |
| from model import build_model | |
| from annotator.hed import HEDdetector | |
| from annotator.canny import CannyDetector | |
| from annotator.midas import MidasDetector | |
| # optical flow | |
| flow_model = GMFlow(feature_channels=128, | |
| num_scales=1, | |
| upsample_factor=8, | |
| num_head=1, | |
| attention_type='swin', | |
| ffn_dim_expansion=4, | |
| num_transformer_layers=6, | |
| ).to('cuda') | |
| checkpoint = torch.load(config['gmflow_path'], map_location=lambda storage, loc: storage) | |
| weights = checkpoint['model'] if 'model' in checkpoint else checkpoint | |
| flow_model.load_state_dict(weights, strict=False) | |
| flow_model.eval() | |
| print('create optical flow estimation model successfully!') | |
| # saliency detection | |
| sod_model = build_model('resnet') | |
| sod_model.load_state_dict(torch.load(config['sod_path'])) | |
| sod_model.to("cuda").eval() | |
| print('create saliency detection model successfully!') | |
| # controlnet | |
| if config['controlnet_type'] not in ['hed', 'depth', 'canny']: | |
| print('unsupported control type, set to hed') | |
| config['controlnet_type'] = 'hed' | |
| controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-"+config['controlnet_type'], | |
| torch_dtype=torch.float16) | |
| controlnet.to("cuda") | |
| if config['controlnet_type'] == 'depth': | |
| detector = MidasDetector() | |
| elif config['controlnet_type'] == 'canny': | |
| detector = CannyDetector() | |
| else: | |
| detector = HEDdetector() | |
| print('create controlnet model-' + config['controlnet_type'] + ' successfully!') | |
| # diffusion model | |
| vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) | |
| pipe = StableDiffusionPipeline.from_pretrained(config['sd_path'], vae=vae, torch_dtype=torch.float16) | |
| pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config) | |
| #noise_scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler") | |
| pipe.to("cuda") | |
| pipe.scheduler.set_timesteps(config['num_inference_steps'], device=pipe._execution_device) | |
| if config['use_freeu']: | |
| from src.free_lunch_utils import apply_freeu | |
| apply_freeu(pipe, b1=1.2, b2=1.5, s1=1.0, s2=1.0) | |
| frescoProc = apply_FRESCO_attn(pipe) | |
| frescoProc.controller.disable_controller() | |
| apply_FRESCO_opt(pipe) | |
| print('create diffusion model ' + config['sd_path'] + ' successfully!') | |
| for param in flow_model.parameters(): | |
| param.requires_grad = False | |
| for param in sod_model.parameters(): | |
| param.requires_grad = False | |
| for param in controlnet.parameters(): | |
| param.requires_grad = False | |
| for param in pipe.unet.parameters(): | |
| param.requires_grad = False | |
| return pipe, frescoProc, controlnet, detector, flow_model, sod_model | |
| def apply_control(x, detector, config): | |
| if config['controlnet_type'] == 'depth': | |
| detected_map, _ = detector(x) | |
| elif config['controlnet_type'] == 'canny': | |
| detected_map = detector(x, 50, 100) | |
| else: | |
| detected_map = detector(x) | |
| return detected_map | |
| def run_keyframe_translation(config): | |
| pipe, frescoProc, controlnet, detector, flow_model, sod_model = get_models(config) | |
| device = pipe._execution_device | |
| guidance_scale = 7.5 | |
| do_classifier_free_guidance = guidance_scale > 1 | |
| assert(do_classifier_free_guidance) | |
| timesteps = pipe.scheduler.timesteps | |
| cond_scale = [config['cond_scale']] * config['num_inference_steps'] | |
| dilate = Dilate(device=device) | |
| base_prompt = config['prompt'] | |
| if 'Realistic' in config['sd_path'] or 'realistic' in config['sd_path']: | |
| a_prompt = ', RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, ' | |
| n_prompt = '(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation' | |
| else: | |
| a_prompt = ', best quality, extremely detailed, ' | |
| n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing finger, extra digit, fewer digits, cropped, worst quality, low quality' | |
| print('\n' + '=' * 100) | |
| print('key frame selection for \"%s\"...'%(config['file_path'])) | |
| video_cap = cv2.VideoCapture(config['file_path']) | |
| frame_num = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| # you can set extra_prompts for individual keyframe | |
| # for example, extra_prompts[38] = ', closed eyes' to specify the person frame38 closes the eyes | |
| extra_prompts = [''] * frame_num | |
| keys = get_keyframe_ind(config['file_path'], frame_num, config['mininterv'], config['maxinterv']) | |
| os.makedirs(config['save_path'], exist_ok=True) | |
| os.makedirs(config['save_path']+'keys', exist_ok=True) | |
| os.makedirs(config['save_path']+'video', exist_ok=True) | |
| sublists = [keys[i:i+config['batch_size']-2] for i in range(2, len(keys), config['batch_size']-2)] | |
| sublists[0].insert(0, keys[0]) | |
| sublists[0].insert(1, keys[1]) | |
| if len(sublists) > 1 and len(sublists[-1]) < 3: | |
| add_num = 3 - len(sublists[-1]) | |
| sublists[-1] = sublists[-2][-add_num:] + sublists[-1] | |
| sublists[-2] = sublists[-2][:-add_num] | |
| if not sublists[-2]: | |
| del sublists[-2] | |
| print('processing %d batches:\nkeyframe indexes'%(len(sublists)), sublists) | |
| print('\n' + '=' * 100) | |
| print('video to video translation...') | |
| batch_ind = 0 | |
| propagation_mode = batch_ind > 0 | |
| imgs = [] | |
| record_latents = [] | |
| video_cap = cv2.VideoCapture(config['file_path']) | |
| for i in range(frame_num): | |
| # prepare a batch of frame based on sublists | |
| success, frame = video_cap.read() | |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| img = resize_image(frame, 512) | |
| H, W, C = img.shape | |
| Image.fromarray(img).save(os.path.join(config['save_path'], 'video/%04d.png'%(i))) | |
| if i not in sublists[batch_ind]: | |
| continue | |
| imgs += [img] | |
| if i != sublists[batch_ind][-1]: | |
| continue | |
| print('processing batch [%d/%d] with %d frames'%(batch_ind+1, len(sublists), len(sublists[batch_ind]))) | |
| # prepare input | |
| batch_size = len(imgs) | |
| n_prompts = [n_prompt] * len(imgs) | |
| prompts = [base_prompt + a_prompt + extra_prompts[ind] for ind in sublists[batch_ind]] | |
| if propagation_mode: # restore the extra_prompts from previous batch | |
| assert len(imgs) == len(sublists[batch_ind]) + 2 | |
| prompts = ref_prompt + prompts | |
| prompt_embeds = pipe._encode_prompt( | |
| prompts, | |
| device, | |
| 1, | |
| do_classifier_free_guidance, | |
| n_prompts, | |
| ) | |
| imgs_torch = torch.cat([numpy2tensor(img) for img in imgs], dim=0) | |
| edges = torch.cat([numpy2tensor(apply_control(img, detector, config)[:, :, None]) for img in imgs], dim=0) | |
| edges = edges.repeat(1,3,1,1).cuda() * 0.5 + 0.5 | |
| if do_classifier_free_guidance: | |
| edges = torch.cat([edges.to(pipe.unet.dtype)] * 2) | |
| if config['use_salinecy']: | |
| saliency = get_saliency(imgs, sod_model, dilate) | |
| else: | |
| saliency = None | |
| # prepare parameters for inter-frame and intra-frame consistency | |
| flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(flow_model, imgs) | |
| correlation_matrix = get_intraframe_paras(pipe, imgs_torch, frescoProc, | |
| prompt_embeds, seed = config['seed']) | |
| ''' | |
| Flexible settings for attention: | |
| * Turn off FRESCO-guided attention: frescoProc.controller.disable_controller() | |
| Then you can turn on one specific attention submodule | |
| * Turn on Cross-frame attention: frescoProc.controller.enable_cfattn(attn_mask) | |
| * Turn on Spatial-guided attention: frescoProc.controller.enable_intraattn() | |
| * Turn on Temporal-guided attention: frescoProc.controller.enable_interattn(interattn_paras) | |
| Flexible settings for optimization: | |
| * Turn off Spatial-guided optimization: set optimize_temporal = False in apply_FRESCO_opt() | |
| * Turn off Temporal-guided optimization: set correlation_matrix = [] in apply_FRESCO_opt() | |
| * Turn off FRESCO-guided optimization: disable_FRESCO_opt(pipe) | |
| Flexible settings for background smoothing: | |
| * Turn off background smoothing: set saliency = None in apply_FRESCO_opt() | |
| ''' | |
| # Turn on all FRESCO support | |
| frescoProc.controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask) | |
| apply_FRESCO_opt(pipe, steps = timesteps[:config['end_opt_step']], | |
| flows = flows, occs = occs, correlation_matrix=correlation_matrix, | |
| saliency=saliency, optimize_temporal = True) | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| # run! | |
| latents = inference(pipe, controlnet, frescoProc, | |
| imgs_torch, prompt_embeds, edges, timesteps, | |
| cond_scale, config['num_inference_steps'], config['num_warmup_steps'], | |
| do_classifier_free_guidance, config['seed'], guidance_scale, config['use_controlnet'], | |
| record_latents, propagation_mode, | |
| flows = flows, occs = occs, saliency=saliency, repeat_noise=True) | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| with torch.no_grad(): | |
| image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] | |
| image = torch.clamp(image, -1 , 1) | |
| save_imgs = tensor2numpy(image) | |
| bias = 2 if propagation_mode else 0 | |
| for ind, num in enumerate(sublists[batch_ind]): | |
| Image.fromarray(save_imgs[ind+bias]).save(os.path.join(config['save_path'], 'keys/%04d.png'%(num))) | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| batch_ind += 1 | |
| # current batch uses the last frame of the previous batch as ref | |
| ref_prompt= [prompts[0], prompts[-1]] | |
| imgs = [imgs[0], imgs[-1]] | |
| propagation_mode = batch_ind > 0 | |
| if batch_ind == len(sublists): | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| break | |
| return keys | |
| def run_full_video_translation(config, keys): | |
| print('\n' + '=' * 100) | |
| if not config['run_ebsynth']: | |
| print('to translate full video with ebsynth, install ebsynth and run:') | |
| else: | |
| print('translating full video with:') | |
| video_cap = cv2.VideoCapture(config['file_path']) | |
| fps = int(video_cap.get(cv2.CAP_PROP_FPS)) | |
| o_video = os.path.join(config['save_path'], 'blend.mp4') | |
| max_process = config['max_process'] | |
| save_path = config['save_path'] | |
| key_ind = io.StringIO() | |
| for k in keys: | |
| print('%d'%(k), end=' ', file=key_ind) | |
| cmd = ( | |
| f'python video_blend.py {save_path} --key keys ' | |
| f'--key_ind {key_ind.getvalue()} --output {o_video} --fps {fps} ' | |
| f'--n_proc {max_process} -ps') | |
| print('\n```') | |
| print(cmd) | |
| print('```') | |
| if config['run_ebsynth']: | |
| os.system(cmd) | |
| print('\n' + '=' * 100) | |
| print('Done') | |
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('config_path', type=str, | |
| default='./config/config_carturn.yaml', | |
| help='The configuration file.') | |
| opt = parser.parse_args() | |
| print('=' * 100) | |
| print('loading configuration...') | |
| with open(opt.config_path, "r") as f: | |
| config = yaml.safe_load(f) | |
| for name, value in sorted(config.items()): | |
| print('%s: %s' % (str(name), str(value))) | |
| keys = run_keyframe_translation(config) | |
| run_full_video_translation(config, keys) | |