Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| from optimization.constants import ASSETS_DIR_NAME, RANKED_RESULTS_DIR | |
| from utils.metrics_accumulator import MetricsAccumulator | |
| from utils.video import save_video | |
| from utils.fft_pytorch import HighFrequencyLoss | |
| from numpy import random | |
| from optimization.augmentations import ImageAugmentations | |
| from PIL import Image | |
| import torch | |
| import torchvision | |
| from torchvision import transforms | |
| import torchvision.transforms.functional as F | |
| from torchvision.transforms import functional as TF | |
| from torch.nn.functional import mse_loss | |
| from optimization.losses import range_loss, d_clip_loss | |
| import lpips | |
| import numpy as np | |
| from CLIP import clip | |
| from guided_diffusion.guided_diffusion.script_util import ( | |
| create_model_and_diffusion, | |
| model_and_diffusion_defaults, | |
| create_classifier, | |
| classifier_defaults, | |
| ) | |
| from utils.visualization import show_tensor_image, show_editied_masked_image | |
| from utils.change_place import change_place, find_bbox | |
| import pdb | |
| import cv2 | |
def create_classifier_ours():
    """Build a ResNet-50 classifier from ``checkpoints/DRA_resnet50.pth``.

    The checkpoint's ``model_state_dict`` keys are renamed before loading:
    every ``module.`` substring is removed and ``last_linear`` becomes
    ``fc`` so the keys match torchvision's ResNet-50 layout.

    Returns:
        torch.nn.Sequential: an ``Upsample(size=(256, 256))`` layer followed
        by the loaded ResNet-50.
    """
    model = torchvision.models.resnet50()
    # Load onto the CPU first so a checkpoint saved from a CUDA device can
    # still be read on a CPU-only host; the weights are copied into `model`
    # by load_state_dict either way.
    checkpoint = torch.load("checkpoints/DRA_resnet50.pth", map_location="cpu")
    state_dict = checkpoint["model_state_dict"]
    renamed = {
        key.replace("module.", "").replace("last_linear", "fc"): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(renamed)
    return torch.nn.Sequential(torch.nn.Upsample(size=(256, 256)), model)
| class ImageEditor: | |
def __init__(self, args) -> None:
    """Create output folders, seed the RNGs and load all models.

    Args:
        args: Options object. Attributes read here: ``output_path``,
            ``export_assets``, ``seed``, ``model_output_size``,
            ``timestep_respacing``, ``gpu_id`` and ``aug_num``.
    """
    self.args = args

    # --- Output directory layout -------------------------------------
    output_root = self.args.output_path
    os.makedirs(output_root, exist_ok=True)
    self.ranked_results_path = Path(os.path.join(output_root, RANKED_RESULTS_DIR))
    os.makedirs(self.ranked_results_path, exist_ok=True)
    if self.args.export_assets:
        self.assets_path = Path(os.path.join(output_root, ASSETS_DIR_NAME))
        os.makedirs(self.assets_path, exist_ok=True)

    # --- Reproducibility ---------------------------------------------
    seed = self.args.seed
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        # `random` is numpy's random module in this file, not the stdlib one.
        random.seed(seed)

    # --- Model / classifier configuration ----------------------------
    output_size = self.args.model_output_size
    diffusion_overrides = {
        "attention_resolutions": "32, 16, 8",
        # Class conditioning is switched on only for the 512 model.
        "class_cond": output_size == 512,
        "diffusion_steps": 1000,
        "rescale_timesteps": True,
        "timestep_respacing": self.args.timestep_respacing,
        "image_size": output_size,
        "learn_sigma": True,
        "noise_schedule": "linear",
        "num_channels": 256,
        "num_head_channels": 64,
        "num_res_blocks": 2,
        "resblock_updown": True,
        "use_fp16": True,
        "use_scale_shift_norm": True,
    }
    self.model_config = model_and_diffusion_defaults()
    self.model_config.update(diffusion_overrides)

    self.classifier_config = classifier_defaults()
    self.classifier_config.update({"image_size": output_size})

    # --- Device ------------------------------------------------------
    if torch.cuda.is_available():
        device_name = f"cuda:{self.args.gpu_id}"
    else:
        device_name = "cpu"
    self.device = torch.device(device_name)
    print("Using device:", self.device)

    # --- Diffusion model ---------------------------------------------
    self.model, self.diffusion = create_model_and_diffusion(**self.model_config)
    if self.args.model_output_size == 256:
        diffusion_checkpoint = "checkpoints/256x256_diffusion_uncond.pt"
    else:
        diffusion_checkpoint = "checkpoints/512x512_diffusion.pt"
    diffusion_state = torch.load(diffusion_checkpoint, map_location="cpu")
    self.model.load_state_dict(diffusion_state)
    self.model.eval().to(self.device)
    # Mark attention / normalisation / projection parameters (matched by
    # name) as requiring gradients.
    for param_name, param in self.model.named_parameters():
        if any(tag in param_name for tag in ("qkv", "norm", "proj")):
            param.requires_grad_()
    if self.model_config["use_fp16"]:
        self.model.convert_to_fp16()

    # --- Classifier --------------------------------------------------
    self.classifier = create_classifier(**self.classifier_config)
    classifier_state = torch.load("checkpoints/256x256_classifier.pt", map_location="cpu")
    self.classifier.load_state_dict(classifier_state)
    self.classifier.eval().to(self.device)
    if self.classifier_config["classifier_use_fp16"]:
        self.classifier.convert_to_fp16()

    # --- CLIP (frozen, eval mode) ------------------------------------
    loaded_clip = clip.load("ViT-B/16", device=self.device, jit=False)[0]
    self.clip_model = loaded_clip.eval().requires_grad_(False)
    self.clip_size = self.clip_model.visual.input_resolution
    self.clip_normalize = transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )

    # --- Helpers used by the losses and the sampling loop ------------
    self.to_tensor = transforms.ToTensor()
    self.lpips_model = lpips.LPIPS(net="vgg").to(self.device)
    self.image_augmentations = ImageAugmentations(self.clip_size, self.args.aug_num)
    self.metrics_accumulator = MetricsAccumulator()
    self.hf_loss = HighFrequencyLoss()
def unscale_timestep(self, t):
    """Map timesteps from the 1000-step scale back to integer step indices.

    Args:
        t: Tensor of timesteps on the 0..1000 scale (the scale implied by
            the hard-coded 1000 below, which matches ``diffusion_steps``
            in the model config).

    Returns:
        torch.Tensor: ``t * num_timesteps / 1000`` as an integer (long)
        tensor.
    """
    scaled = t * (self.diffusion.num_timesteps / 1000)
    # Round before casting: `.long()` truncates, so a floating-point result
    # such as 6.9999995 would otherwise silently become step 6 instead of 7.
    return torch.round(scaled).long()
def clip_loss(self, x_in, text_embed):
    """Return the CLIP distance between augmented views of ``x_in`` and ``text_embed``.

    ``x_in`` is multiplied by ``self.mask`` (when a mask is set), augmented,
    shifted from [-1, 1] to [0, 1], CLIP-normalised and encoded. The
    per-view distances are averaged over the augmentations of each batch
    item and those averages are summed.

    Args:
        x_in: Image batch fed to the augmentation pipeline.
        text_embed: Target embedding passed to ``d_clip_loss``.

    Returns:
        torch.Tensor: scalar loss.
    """
    masked_input = x_in if self.mask is None else x_in * self.mask

    # shape: [N,C,H,W], range: [0,1]
    augmented_input = self.image_augmentations(masked_input).add(1).div(2)
    image_embeds = self.clip_model.encode_image(self.clip_normalize(augmented_input)).float()
    dists = d_clip_loss(image_embeds, text_embed)

    # dists[i::batch] picks every entry belonging to batch item i (one per
    # augmentation); average those, then sum the per-item averages.
    batch = self.args.batch_size
    per_item_means = (dists[item::batch].mean() for item in range(batch))
    return sum(per_item_means, torch.tensor(0))
def unaugmented_clip_distance(self, x, text_embed):
    """Return the CLIP distance of ``x`` to ``text_embed`` without augmentations.

    Args:
        x: Image tensor; it is resized to the CLIP input resolution.
        text_embed: Target embedding passed to ``d_clip_loss``.

    Returns:
        float: the distance, extracted with ``.item()`` (so the result must
        hold exactly one element).
    """
    # NOTE(review): unlike clip_loss, the input is not passed through
    # self.clip_normalize before encoding - confirm this is intended.
    clip_hw = [self.clip_size, self.clip_size]
    resized = F.resize(x, clip_hw)
    embedding = self.clip_model.encode_image(resized).float()
    return d_clip_loss(embedding, text_embed).item()
def model_fn(self, x, t, y=None):
    """Call the diffusion model, forwarding ``y`` only when ``args.class_cond`` is set.

    Args:
        x: Model input.
        t: Timesteps.
        y: Optional class labels; replaced by ``None`` unless
            ``self.args.class_cond`` is truthy.

    Returns:
        Whatever ``self.model`` returns.
    """
    label = None
    if self.args.class_cond:
        label = y
    return self.model(x, t, label)
def edit_image_by_prompt(self):
    """Run guided diffusion editing and write the results to disk.

    Steps:
      1. Build the guidance embedding from ``args.prompt`` - a text prompt,
         or the path of an image when ``args.image_guide`` is set.
      2. Load ``args.init_image`` (passed to the sampler as ``init_image``)
         and ``args.init_image_2`` (blended in outside the mask).
      3. Load, binarise and optionally invert / relocate the mask.
      4. For each of ``args.iterations_num`` iterations, run the sampler
         with ``cond_fn`` guidance and save intermediate / final images.
      5. Write a horizontal contact sheet of up to eight ranked results to
         ``<args.final_save_root>/edited.png``.

    Side effects: sets ``self.image_size``, ``self.init_image*``,
    ``self.mask``, ``self.mask_pil`` and ``self.mask_d``; writes files under
    ``args.output_path``, the ranked-results directory, the assets
    directory (when ``args.export_assets``) and ``args.final_save_root``.
    """
    # ---- 1. Guidance embedding --------------------------------------
    if self.args.image_guide:
        # In image-guide mode `args.prompt` holds an image path.
        img_guidance = Image.open(self.args.prompt).convert("RGB")
        img_guidance = img_guidance.resize((224, 224), Image.LANCZOS)  # type: ignore
        img_guidance = self.clip_normalize(
            self.to_tensor(img_guidance).unsqueeze(0)
        ).to(self.device)
        text_embed = self.clip_model.encode_image(img_guidance).float()
    else:
        text_embed = self.clip_model.encode_text(
            clip.tokenize(self.args.prompt).to(self.device)
        ).float()

    # ---- 2. Input images, scaled to [-1, 1] -------------------------
    self.image_size = (self.model_config["image_size"], self.model_config["image_size"])
    self.init_image_pil = Image.open(self.args.init_image).convert("RGB")
    self.init_image_pil = self.init_image_pil.resize(self.image_size, Image.LANCZOS)  # type: ignore
    self.init_image = (
        TF.to_tensor(self.init_image_pil).to(self.device).unsqueeze(0).mul(2).sub(1)
    )

    self.init_image_pil_2 = Image.open(self.args.init_image_2).convert("RGB")
    if self.args.rotate_obj:
        self.init_image_pil_2 = self.init_image_pil_2.rotate(self.args.angle)
    self.init_image_pil_2 = self.init_image_pil_2.resize(self.image_size, Image.LANCZOS)  # type: ignore
    self.init_image_2 = (
        TF.to_tensor(self.init_image_pil_2).to(self.device).unsqueeze(0).mul(2).sub(1)
    )

    if self.args.export_assets:
        img_path = self.assets_path / Path(self.args.output_file)
        self.init_image_pil.save(img_path, quality=100)

    # ---- 3. Mask (all ones when no mask file is given) --------------
    self.mask = torch.ones_like(self.init_image, device=self.device)
    self.mask_pil = None
    if self.args.mask is not None:
        self.mask_pil = Image.open(self.args.mask).convert("RGB")
        if self.args.rotate_obj:
            # Keep the mask aligned with the rotated second image.
            self.mask_pil = self.mask_pil.rotate(self.args.angle)
        if self.mask_pil.size != self.image_size:
            self.mask_pil = self.mask_pil.resize(self.image_size, Image.NEAREST)  # type: ignore
        if self.args.random_position:
            # The bounding box is taken before binarisation / inversion.
            bbox = find_bbox(np.array(self.mask_pil))
            print(bbox)

        image_mask_pil_binarized = ((np.array(self.mask_pil) > 0.5) * 255).astype(np.uint8)
        if self.args.invert_mask:
            image_mask_pil_binarized = 255 - image_mask_pil_binarized
            self.mask_pil = TF.to_pil_image(image_mask_pil_binarized)
        self.mask = TF.to_tensor(Image.fromarray(image_mask_pil_binarized))
        # Keep a single channel, shaped [1, 1, H, W].
        self.mask = self.mask[0, ...].unsqueeze(0).unsqueeze(0).to(self.device)

        if self.args.random_position:
            self.init_image_2, self.mask = change_place(
                self.init_image_2, self.mask, bbox, self.args.invert_mask
            )

        if self.args.export_assets:
            mask_path = self.assets_path / Path(
                self.args.output_file.replace(".png", "_mask.png")
            )
            self.mask_pil.save(mask_path, quality=100)

    def class_guided(x, y, t):
        """Return the negated class log-probability gradient, scaled by ``classifier_scale``."""
        assert y is not None
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            # The classifier built by create_classifier() takes the timestep
            # as a second argument; calling it with the image alone only
            # suits the (currently disabled) create_classifier_ours() model.
            # NOTE(review): revert to `self.classifier(x_in)` if that
            # ResNet classifier is re-enabled in __init__.
            logits = self.classifier(x_in, t)
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            loss = selected.sum()
            return -torch.autograd.grad(loss, x_in)[0] * self.args.classifier_scale

    def cond_fn(x, t, y=None):
        """Return the guidance gradient with respect to ``x`` for the sampler."""
        if self.args.prompt == "":
            return torch.zeros_like(x)

        with torch.enable_grad():
            x = x.detach().requires_grad_()
            t_unscale = self.unscale_timestep(t)
            out = self.diffusion.p_mean_variance(
                self.model, x, t_unscale, clip_denoised=False, model_kwargs={"y": None}
            )
            # Guidance losses are evaluated on the predicted clean image.
            x_in = out["pred_xstart"]

            use_classifier = self.args.classifier_scale != 0 and y is not None
            if use_classifier:
                gradient_class_guided = class_guided(x_in, y, t)

            loss = torch.tensor(0)
            if self.args.background_complex != 0:
                hf_term = self.args.background_complex * self.hf_loss((x_in + 1.0) / 2.0)
                # `hard` flips the sign of the self.hf_loss term.
                if self.args.hard:
                    loss = loss - hf_term
                else:
                    loss = loss + hf_term

            if self.args.clip_guidance_lambda != 0:
                clip_loss = self.clip_loss(x_in, text_embed) * self.args.clip_guidance_lambda
                loss = loss + clip_loss
                self.metrics_accumulator.update_metric("clip_loss", clip_loss.item())

            if self.args.range_lambda != 0:
                r_loss = range_loss(out["pred_xstart"]).sum() * self.args.range_lambda
                loss = loss + r_loss
                self.metrics_accumulator.update_metric("range_loss", r_loss.item())

            if self.args.background_preservation_loss:
                # Blend the clean prediction with the noisy input using the
                # noise level of the current step (only needed here).
                fac = self.diffusion.sqrt_one_minus_alphas_cumprod[t_unscale[0].item()]
                x_in = out["pred_xstart"] * fac + x * (1 - fac)

                # Since 2022.07.19 both sides are compared inside the mask
                # (x * mask vs. init_image * mask) rather than outside it.
                if self.mask is not None:
                    masked_background = x_in * self.mask
                else:
                    masked_background = x_in

                if self.args.lpips_sim_lambda:
                    loss = (
                        loss
                        + self.lpips_model(masked_background, self.init_image * self.mask).sum()
                        * self.args.lpips_sim_lambda
                    )
                if self.args.l2_sim_lambda:
                    loss = (
                        loss
                        + mse_loss(masked_background, self.init_image * self.mask)
                        * self.args.l2_sim_lambda
                    )

            # With every loss term disabled `loss` is a constant and
            # autograd.grad would raise; use a zero gradient instead.
            if loss.requires_grad:
                guidance = -torch.autograd.grad(loss, x)[0]
            else:
                guidance = torch.zeros_like(x)
            if use_classifier:
                guidance = guidance + gradient_class_guided
            return guidance

    def postprocess_fn(out, t):
        """Blend the noised second image into ``out["sample"]`` outside the mask."""
        kernel = 0
        if self.args.coarse_to_fine:
            # `t` holds one timestep per batch item; use the first entry
            # (as q_sample below does) so batch sizes > 1 do not trigger an
            # ambiguous tensor truth-value error.
            step = t[0].item()
            if step > 50:
                kernel = 51
            elif step > 35:
                kernel = 31

        if kernel > 0:
            # Max-pooling the complement dilates it, i.e. erodes the mask.
            max_pool = torch.nn.MaxPool2d(
                kernel_size=kernel, stride=1, padding=int((kernel - 1) / 2)
            )
            self.mask_d = 1 - max_pool(1 - self.mask)
        else:
            self.mask_d = self.mask

        if self.mask is not None:
            background_stage_t = self.diffusion.q_sample(self.init_image_2, t[0])
            background_stage_t = torch.tile(
                background_stage_t, dims=(self.args.batch_size, 1, 1, 1)
            )
            out["sample"] = out["sample"] * self.mask_d + background_stage_t * (1 - self.mask_d)
        return out

    # ---- 4. Sampling ------------------------------------------------
    # max(1, ...) avoids a modulo-by-zero when there are fewer than 5 steps.
    save_image_interval = max(1, self.diffusion.num_timesteps // 5)
    for iteration_number in range(self.args.iterations_num):
        print(f"Start iterations {iteration_number}")

        sample_func = (
            self.diffusion.ddim_sample_loop_progressive
            if self.args.ddim
            else self.diffusion.p_sample_loop_progressive
        )
        if self.args.classifier_scale == 0:
            model_kwargs = {}
        else:
            model_kwargs = {
                "y": self.args.y
                * torch.ones([self.args.batch_size], device=self.device, dtype=torch.long)
            }
        samples = sample_func(
            self.model_fn,
            (
                self.args.batch_size,
                3,
                self.model_config["image_size"],
                self.model_config["image_size"],
            ),
            clip_denoised=False,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            device=self.device,
            progress=True,
            skip_timesteps=self.args.skip_timesteps,
            init_image=self.init_image,
            postprocess_fn=None if self.args.local_clip_guided_diffusion else postprocess_fn,
            randomize_class=self.args.classifier_scale == 0,
        )

        intermediate_samples = [[] for _ in range(self.args.batch_size)]
        total_steps = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
        for j, sample in enumerate(samples):
            should_save_image = j % save_image_interval == 0 or j == total_steps
            if should_save_image or self.args.save_video:
                self.metrics_accumulator.print_average_metric()

                for b in range(self.args.batch_size):
                    pred_image = sample["pred_xstart"][b]
                    visualization_path = Path(
                        os.path.join(self.args.output_path, self.args.output_file)
                    )
                    visualization_path = visualization_path.with_stem(
                        f"{visualization_path.stem}_i_{iteration_number}_b_{b}"
                    )

                    if (
                        self.mask is not None
                        and self.args.enforce_background
                        and j == total_steps
                        and not self.args.local_clip_guided_diffusion
                    ):
                        # Final step: paste the second image back outside the mask.
                        pred_image = (
                            self.init_image_2[0] * (1 - self.mask[0])
                            + pred_image * self.mask[0]
                        )

                    pred_image = pred_image.add(1).div(2).clamp(0, 1)
                    pred_image_pil = TF.to_pil_image(pred_image)
                    masked_pred_image = self.mask * pred_image.unsqueeze(0)
                    final_distance = self.unaugmented_clip_distance(
                        masked_pred_image, text_embed
                    )
                    formatted_distance = f"{final_distance:.4f}"

                    if self.args.export_assets:
                        pred_path = self.assets_path / visualization_path.name
                        pred_image_pil.save(pred_path, quality=100)

                    if j == total_steps:
                        # Prefix the file name with the distance so that
                        # sorting by name ranks the results.
                        path_friendly_distance = formatted_distance.replace(".", "")
                        ranked_pred_path = self.ranked_results_path / (
                            path_friendly_distance + "_" + visualization_path.name
                        )
                        pred_image_pil.save(ranked_pred_path, quality=100)

                    intermediate_samples[b].append(pred_image_pil)
                    if should_save_image:
                        show_editied_masked_image(
                            title=self.args.prompt,
                            source_image=self.init_image_pil,
                            edited_image=pred_image_pil,
                            mask=self.mask_pil,
                            path=visualization_path,
                            distance=formatted_distance,
                        )

        if self.args.save_video:
            for b in range(self.args.batch_size):
                video_name = self.args.output_file.replace(
                    ".png", f"_i_{iteration_number}_b_{b}.avi"
                )
                video_path = os.path.join(self.args.output_path, video_name)
                save_video(intermediate_samples[b], video_path)

    # ---- 5. Contact sheet of the ranked results ---------------------
    # Only ranked outputs are shown. The source image and mask used to be
    # read here but were never included, and reading them crashed when no
    # mask was given. Entries are sorted so the choice is deterministic.
    visualize_size = (256, 256)
    tiles = []
    for img_name in sorted(os.listdir(self.ranked_results_path))[:8]:
        img = cv2.imread(str(self.ranked_results_path / img_name))
        if img is None:
            # Not a readable image file; skip it instead of crashing.
            continue
        tiles.append(cv2.resize(img, visualize_size))

    if tiles:
        # cv2.imwrite fails silently when the directory is missing.
        os.makedirs(self.args.final_save_root, exist_ok=True)
        img_whole = cv2.hconcat(tiles)
        cv2.imwrite(
            os.path.join(self.args.final_save_root, 'edited.png'),
            img_whole,
            [int(cv2.IMWRITE_PNG_COMPRESSION), 0],
        )
def reconstruct_image(self):
    """Run unguided sampling from ``args.init_image`` and save the prediction.

    The image is loaded, resized to the model's resolution, scaled to
    [-1, 1] and passed to ``p_sample_loop_progressive`` as ``init_image``
    with no ``cond_fn``. The current ``pred_xstart`` is written to
    ``<args.output_path>/<args.output_file>`` at regular intervals and at
    the last step, overwriting the same file each time.
    """
    # Derive the size from the model config rather than self.image_size,
    # which only exists after edit_image_by_prompt() has run.
    side = self.model_config["image_size"]

    init = Image.open(self.args.init_image).convert("RGB")
    init = init.resize((side, side), Image.LANCZOS)  # type: ignore
    init = TF.to_tensor(init).to(self.device).unsqueeze(0).mul(2).sub(1)

    if self.args.model_output_size == 256:
        model_kwargs = {}
    else:
        # One label for the single image sampled below; sizing this by
        # args.batch_size would mismatch the batch of 1.
        model_kwargs = {"y": torch.zeros([1], device=self.device, dtype=torch.long)}

    samples = self.diffusion.p_sample_loop_progressive(
        self.model,
        (1, 3, side, side),
        clip_denoised=False,
        model_kwargs=model_kwargs,
        cond_fn=None,
        progress=True,
        skip_timesteps=self.args.skip_timesteps,
        init_image=init,
        randomize_class=True,
    )

    # max(1, ...) avoids a modulo-by-zero when there are fewer than 5 steps.
    save_image_interval = max(1, self.diffusion.num_timesteps // 5)
    max_iterations = self.diffusion.num_timesteps - self.args.skip_timesteps - 1
    filename = os.path.join(self.args.output_path, self.args.output_file)

    for j, sample in enumerate(samples):
        if j % save_image_interval == 0 or j == max_iterations:
            print()
            prediction = sample["pred_xstart"][0].add(1).div(2).clamp(0, 1)
            TF.to_pil_image(prediction).save(filename)