Spaces:
Runtime error
Runtime error
| # ************************************************************************* | |
| # Copyright (2023) Bytedance Inc. | |
| # | |
| # Copyright (2023) DragDiffusion Authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ************************************************************************* | |
| # Run DragDiffusion on every benchmark sample and save the dragged results | |
| import argparse | |
| import os | |
| import datetime | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import pickle | |
| import PIL | |
| from PIL import Image | |
| from copy import deepcopy | |
| from einops import rearrange | |
| from types import SimpleNamespace | |
| from diffusers import DDIMScheduler, AutoencoderKL | |
| from torchvision.utils import save_image | |
| from pytorch_lightning import seed_everything | |
| import sys | |
| sys.path.insert(0, '../') | |
| from drag_pipeline import DragPipeline | |
| from utils.drag_utils import drag_diffusion_update | |
| from utils.attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl | |
def preprocess_image(image,
                     device):
    """Turn an image array into a normalized batch tensor on ``device``.

    The array is wrapped with ``torch.from_numpy``, cast to float and mapped
    through ``x / 127.5 - 1`` (so 0..255 input lands in [-1, 1]).  The axes
    are then rearranged from ``h w c`` to ``1 c h w`` and the result is moved
    to ``device``.
    """
    pixels = torch.from_numpy(image).float()
    normalized = pixels / 127.5 - 1  # [-1, 1]
    batched = rearrange(normalized, "h w c -> 1 c h w")
    return batched.to(device)
| # A copy of the run_drag function, reproduced here so this script can call it directly | |
def run_drag(source_image,
             # image_with_clicks,
             mask,
             prompt,
             points,
             inversion_strength,
             lam,
             latent_lr,
             unet_feature_idx,
             n_pix_step,
             model_path,
             vae_path,
             lora_path,
             start_step,
             start_layer,
             # save_dir="./results"
             ):
    """Run one drag edit on ``source_image`` and return the edited image.

    The pipeline is loaded from ``model_path`` on every call, the source
    image is inverted to a latent code, the latent is optimised by
    ``drag_diffusion_update`` and the result is decoded with mutual
    self-attention control against the un-edited latent.

    Args:
        source_image: numpy array indexed as (height, width, channel).
        mask: 2-D numpy array (height, width); it is divided by 255 and every
            positive entry is then set to 1.
        prompt: text prompt passed to both inversion and generation.
        points: sequence of (x, y) points; entries at even indices are used
            as handle points and entries at odd indices as target points.
        inversion_strength: multiplied by the 50 inference steps and rounded
            to get the number of steps actually inverted/denoised.
        lam: stored as ``args.lam`` for ``drag_diffusion_update``.
        latent_lr: stored as ``args.lr`` for ``drag_diffusion_update``.
        unet_feature_idx: single index, wrapped in a list as
            ``args.unet_feature_idx``.
        n_pix_step: stored as ``args.n_pix_step`` for
            ``drag_diffusion_update``.
        model_path: identifier handed to ``DragPipeline.from_pretrained``.
        vae_path: ``"default"`` keeps the pipeline's VAE; any other value is
            loaded with ``AutoencoderKL.from_pretrained``.
        lora_path: ``""`` selects the default attention processors; any other
            value is loaded with ``unet.load_attn_procs``.
        start_step: forwarded to ``MutualSelfAttentionControl``.
        start_layer: forwarded to ``MutualSelfAttentionControl``.

    Returns:
        numpy ``uint8`` array indexed as (height, width, channel), resized to
        the height and width of ``source_image``.
    """
    # initialize model
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                              beta_schedule="scaled_linear", clip_sample=False,
                              set_alpha_to_one=False, steps_offset=1)
    model = DragPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
    # call this function to override unet forward function,
    # so that intermediate features are returned after forward
    model.modify_unet_forward()

    # set vae
    if vae_path != "default":
        model.vae = AutoencoderKL.from_pretrained(
            vae_path
        ).to(model.vae.device, model.vae.dtype)

    # initialize parameters
    seed = 42  # fixed seed so repeated runs are reproducible
    seed_everything(seed)

    args = SimpleNamespace()
    args.prompt = prompt
    args.points = points
    args.n_inference_step = 50
    args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
    args.guidance_scale = 1.0

    args.unet_feature_idx = [unet_feature_idx]

    args.r_m = 1
    args.r_p = 3
    args.lam = lam

    args.lr = latent_lr
    args.n_pix_step = n_pix_step

    # supervision is applied at half the resolution of the input image
    full_h, full_w = source_image.shape[:2]
    args.sup_res_h = int(0.5*full_h)
    args.sup_res_w = int(0.5*full_w)

    print(args)

    source_image = preprocess_image(source_image, device)
    # image_with_clicks = preprocess_image(image_with_clicks, device)

    # set lora
    if lora_path == "":
        print("applying default parameters")
        model.unet.set_default_attn_processor()
    else:
        print("applying lora: " + lora_path)
        model.unet.load_attn_procs(lora_path)

    # invert the source image
    # the latent code resolution is too small, only 64*64
    invert_code = model.invert(source_image,
                               prompt,
                               guidance_scale=args.guidance_scale,
                               num_inference_steps=args.n_inference_step,
                               num_actual_inference_steps=args.n_actual_inference_step)

    # binarize the mask and bring it to the supervision resolution
    mask = torch.from_numpy(mask).float() / 255.
    mask[mask > 0.0] = 1.0
    # Move the mask to the device selected above; a hard-coded .cuda() here
    # would crash on CPU-only hosts even though `device` falls back to CPU.
    mask = rearrange(mask, "h w -> 1 1 h w").to(device)
    mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest")

    handle_points = []
    target_points = []
    # here, the point is in x,y coordinate; it is stored as (row, col)
    # after rescaling to the supervision resolution
    for idx, point in enumerate(points):
        cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w])
        cur_point = torch.round(cur_point)
        if idx % 2 == 0:
            handle_points.append(cur_point)
        else:
            target_points.append(cur_point)
    print('handle points:', handle_points)
    print('target points:', target_points)

    init_code = invert_code
    # keep an untouched copy to serve as the reference branch at generation
    init_code_orig = deepcopy(init_code)
    model.scheduler.set_timesteps(args.n_inference_step)
    t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]

    # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
    # update according to the given supervision
    updated_init_code = drag_diffusion_update(model, init_code,
                                              None, t, handle_points, target_points, mask, args)

    # hijack the attention module
    # inject the reference branch to guide the generation
    editor = MutualSelfAttentionControl(start_step=start_step,
                                        start_layer=start_layer,
                                        total_steps=args.n_inference_step,
                                        guidance_scale=args.guidance_scale)
    if lora_path == "":
        register_attention_editor_diffusers(model, editor, attn_processor='attn_proc')
    else:
        register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc')

    # inference the synthesized image
    # batch entry 0 is the reference (original latent), entry 1 the edited one
    gen_image = model(
        prompt=args.prompt,
        batch_size=2,
        latents=torch.cat([init_code_orig, updated_init_code], dim=0),
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.n_inference_step,
        num_actual_inference_steps=args.n_actual_inference_step
    )[1].unsqueeze(dim=0)

    # resize gen_image into the size of source_image
    # we do this because shape of gen_image will be rounded to multipliers of 8
    gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear')

    # save the original image, user editing instructions, synthesized image
    # save_result = torch.cat([
    #     source_image * 0.5 + 0.5,
    #     torch.ones((1,3,full_h,25)).cuda(),
    #     image_with_clicks * 0.5 + 0.5,
    #     torch.ones((1,3,full_h,25)).cuda(),
    #     gen_image[0:1]
    # ], dim=-1)

    # if not os.path.isdir(save_dir):
    #     os.mkdir(save_dir)
    # save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
    # save_image(save_result, os.path.join(save_dir, save_prefix + '.png'))

    # channel-last layout, first (only) batch entry, scaled to 8-bit
    out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0]
    out_image = (out_image * 255).astype(np.uint8)
    return out_image
# Script entry point: run run_drag on every sample under root_dir and write
# each result to <result_dir>/<category>/<sample>/dragged_image.png.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="setting arguments")
    # Both options feed directly into paths / arithmetic below, so a missing
    # value must be rejected up front instead of propagating as None.
    parser.add_argument('--lora_steps', type=int, required=True,
                        help='number of lora fine-tuning steps')
    parser.add_argument('--inv_strength', type=float, required=True,
                        help='inversion strength')
    parser.add_argument('--latent_lr', type=float, default=0.01, help='latent learning rate')
    parser.add_argument('--unet_feature_idx', type=int, default=3, help='feature idx of unet features')
    args = parser.parse_args()

    all_category = [
        'art_work',
        'land_scape',
        'building_city_view',
        'building_countryside_view',
        'animals',
        'human_head',
        'human_upper_body',
        'human_full_body',
        'interior_design',
        'other_objects',
    ]

    # assume root_dir and lora_dir are valid directory
    root_dir = 'drag_bench_data'
    lora_dir = 'drag_bench_lora'
    # the result directory name encodes the hyper-parameters of this run
    result_dir = 'drag_diffusion_res' + \
        '_' + str(args.lora_steps) + \
        '_' + str(args.inv_strength) + \
        '_' + str(args.latent_lr) + \
        '_' + str(args.unet_feature_idx)

    # mkdir if necessary; exist_ok makes this safe when result_dir already
    # exists but some category folders are missing (e.g. a resumed run)
    for cat in all_category:
        os.makedirs(os.path.join(result_dir, cat), exist_ok=True)

    for cat in all_category:
        file_dir = os.path.join(root_dir, cat)
        for sample_name in os.listdir(file_dir):
            if sample_name == '.DS_Store':
                continue
            sample_path = os.path.join(file_dir, sample_name)

            # read image file; close the handle promptly and force three
            # channels so RGBA / palette / grayscale PNGs are accepted too
            with Image.open(os.path.join(sample_path, 'original_image.png')) as img:
                source_image = np.array(img.convert('RGB'))

            # load meta data
            # NOTE: pickle.load can execute arbitrary code contained in the
            # file; only run this script on trusted benchmark data.
            with open(os.path.join(sample_path, 'meta_data.pkl'), 'rb') as f:
                meta_data = pickle.load(f)
            prompt = meta_data['prompt']
            mask = meta_data['mask']
            points = meta_data['points']

            # load lora
            lora_path = os.path.join(lora_dir, cat, sample_name, str(args.lora_steps))
            print("applying lora: " + lora_path)

            out_image = run_drag(
                source_image,
                mask,
                prompt,
                points,
                inversion_strength=args.inv_strength,
                lam=0.1,
                latent_lr=args.latent_lr,
                unet_feature_idx=args.unet_feature_idx,
                n_pix_step=80,
                model_path="runwayml/stable-diffusion-v1-5",
                vae_path="default",
                lora_path=lora_path,
                start_step=0,
                start_layer=10,
            )
            save_dir = os.path.join(result_dir, cat, sample_name)
            os.makedirs(save_dir, exist_ok=True)
            Image.fromarray(out_image).save(os.path.join(save_dir, 'dragged_image.png'))