# NOTE: the following two lines are Hugging Face Spaces page residue
# (site status header, not code): "Spaces: Sleeping / Sleeping"
| import random | |
| import os | |
| import gc | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| from diffusers import DDIMScheduler, StableDiffusionPipeline | |
| from pytorch_lightning import seed_everything | |
| import torch | |
| from scipy.ndimage import gaussian_filter | |
| import sys | |
| sys.path.append("./scripts") | |
| from dyn_mask import DynMask, get_surround | |
| from arguments import parse_args | |
| from clicker import ClickCreate, ClickDraw | |
| from augmentations import ImageAugmentations | |
| from constants import Const, N | |
def read_image(image: Image.Image, device, dest_size):
    """Convert a PIL image to a (1, C, H, W) float tensor in [-1, 1] on `device`.

    The image is forced to RGB and resized (LANCZOS) only when its size
    differs from ``dest_size``.
    """
    rgb = image.convert("RGB")
    if dest_size != rgb.size:
        rgb = rgb.resize(dest_size, Image.LANCZOS)
    pixels = np.array(rgb).astype(np.float32) / 255.0
    # HWC -> NCHW, then rescale [0, 1] -> [-1, 1]
    tensor = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2)).to(device)
    return tensor * 2.0 - 1.0
class Click2Mask:
    """Click-to-mask image editing with Stable Diffusion.

    Given a source image, a user click and a text prompt, evolves a latent-space
    mask (``DynMask``) during blended latent diffusion, renders several candidate
    edits with the final mask, and scores them with an alpha-CLIP similarity.
    """

    def __init__(self):
        # Parsed CLI/config options (model path, GPU id, prompt, n_masks, ...).
        self.args = parse_args()
        self.device = torch.device(f"cuda:{self.args.gpu_id}")
        self.load_models()

    def load_models(self):
        """Load fp16 Stable Diffusion modules onto the GPU and build a DDIM scheduler."""
        pipe = StableDiffusionPipeline.from_pretrained(
            self.args.model_path, torch_dtype=torch.float16
        )
        self.vae = pipe.vae.to(self.device)
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder.to(self.device)
        self.unet = pipe.unet.to(self.device)
        self.scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )

    def blended_latent_diffusion(
        self,
        dyn_mask,
        create_dyn_mask,
        seed,
        original_rand_latents,
        scheduler,
        blending_percentage,
        total_steps,
        source_latents,
        text_embeddings,
        guidance_scale,
        dyn_start_step_i=None,
        dyn_cond_stop_step_i=None,
        dyn_final_stop_step_i=None,
        max_area_ratio_for_dilation=None,
        last_step_threshed_latent_mask=None,
        rerun_return_during_step_i=None,
    ):
        """Run blended latent diffusion, optionally evolving the dynamic mask.

        Three modes share this loop:
        * mask evolution (``create_dyn_mask=True``): on each step in
          ``update_steps`` the mask is evolved and a "rerun" (recursive call)
          replays diffusion with the new mask up to the current step;
        * rerun (``rerun_return_during_step_i`` set): replays with plain mask
          dilation and returns early at the requested step;
        * final run (``create_dyn_mask=False``, no rerun step): plain dilation
          of the frozen mask for the whole schedule.

        Returns:
            (latents, latent_mask) at the point the loop stops.
        """
        seed_everything(seed)
        use_plain_dilation_from_latent_mask = not create_dyn_mask
        # Skip the first `blending_percentage` fraction of the timestep schedule.
        blending_steps_t = scheduler.timesteps[
            int(len(scheduler.timesteps) * blending_percentage) :
        ]
        latents = original_rand_latents
        if create_dyn_mask:
            update_steps = list(range(dyn_start_step_i, dyn_cond_stop_step_i + 1))
            # NOTE(review): `0 != u < len(...)` chains to `u != 0 and u < len(...)`;
            # confirm `0 <= u` was not the intent.
            update_steps = [u for u in update_steps if 0 != u < len(blending_steps_t)]
            first_update_step, orig_last_update_step = update_steps[0], update_steps[-1]
            best_step_i = orig_last_update_step
        if last_step_threshed_latent_mask is not None:
            latent_mask = last_step_threshed_latent_mask
        for step_i, t in enumerate(blending_steps_t):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(
                latent_model_input, timestep=t
            )
            # predict the noise residual
            with torch.no_grad():
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings
                ).sample
            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            # FIX: the original called scheduler.step twice with identical inputs
            # (once for pred_original_sample, once for prev_sample). DDIM's step
            # is stateless/deterministic, so one call yields both.
            step_out = scheduler.step(noise_pred, t, latents)
            latent_pred_z0 = step_out.pred_original_sample
            # compute the previous noisy sample x_t -> x_t-1
            latents = step_out.prev_sample
            if rerun_return_during_step_i == step_i:
                # Rerun mode: hand the partially-denoised latents back to the caller.
                return latents, latent_mask
            # dilation for rerun + final runs
            elif use_plain_dilation_from_latent_mask:
                latent_mask = dyn_mask.get_plain_dilated_latent_mask(
                    last_step_latent_mask=last_step_threshed_latent_mask,
                    step_i=step_i,
                    total_steps=total_steps,
                    max_area_ratio_for_dilation=max_area_ratio_for_dilation,
                    rerun_dyn_start_step_i=None
                    if not rerun_return_during_step_i
                    else dyn_start_step_i,
                )
            # mask evolution
            elif create_dyn_mask:
                if step_i in update_steps:
                    latent_mask = dyn_mask.evolve_mask(
                        step_i=step_i,
                        decoder=self.vae.decode,
                        latent_pred_z0=latent_pred_z0,
                        source_latents=source_latents,
                        return_only=N.LATENT_MASK,
                    )
                    # Rerun: replay diffusion from scratch with the newly evolved
                    # mask, stopping at (and returning from) the current step.
                    latents, _ = self.blended_latent_diffusion(
                        dyn_mask,
                        create_dyn_mask=False,
                        seed=seed,
                        original_rand_latents=original_rand_latents,
                        scheduler=scheduler,
                        blending_percentage=blending_percentage,
                        total_steps=total_steps,
                        source_latents=source_latents,
                        text_embeddings=text_embeddings,
                        guidance_scale=guidance_scale,
                        dyn_start_step_i=dyn_start_step_i,
                        max_area_ratio_for_dilation=Const.RERUN_MAX_AREA_RATIO_FOR_DILATION,
                        last_step_threshed_latent_mask=latent_mask,
                        rerun_return_during_step_i=step_i,
                    )
                elif step_i < first_update_step:  # initial dilation
                    latent_mask = dyn_mask.set_cur_masks(
                        step_i=step_i, return_only=N.LATENT_MASK
                    )
            # Blending: keep the edit inside the mask, renoised source outside it.
            noise_source_latents = scheduler.add_noise(
                source_latents, torch.randn_like(latents), t
            )
            latents = latents * latent_mask + noise_source_latents * (1 - latent_mask)
            if create_dyn_mask:
                if step_i >= orig_last_update_step:
                    # Cache mask/latents so we can roll back to the best-scoring step.
                    dyn_mask.make_cached_masks_clones(name=step_i)
                    dyn_mask.latents_hist[step_i] = latents
                    dyn_mask.latent_masks_hist[step_i] = latent_mask
                    if step_i >= orig_last_update_step + 2:
                        step_prev1_better = (
                            dyn_mask.closs_hist[step_i - 1]
                            < dyn_mask.closs_hist[step_i - 2]
                        )
                        if step_prev1_better:
                            best_step_i = step_i - 1
                        if (not step_prev1_better) or (step_i > dyn_final_stop_step_i):
                            # we need an extra step to calculate clip loss for last evolved mask
                            latents = dyn_mask.latents_hist[best_step_i]
                            latent_mask = dyn_mask.latent_masks_hist[best_step_i]
                            dyn_mask.set_masks_from_cached_masks_clones(
                                name=best_step_i
                            )
                            break
                    # Keep evolving one extra step while the CLIP loss improves.
                    update_steps.append(step_i + 1)
        return latents, latent_mask

    def edit_image(
        self,
        image_pil,
        click_pil,
        prompts,
        height,
        width,
        num_inference_steps,
        num_static_inference_steps,
        guidance_scale,
        seed,
        blending_percentage,
    ):
        """Evolve a mask for (image, click), render several candidates, score them.

        Returns the list of result dicts produced by score_and_arrange_results
        (each with "im", "latent_mask" and "dist" keys).
        """
        generator = torch.manual_seed(seed)
        batch_size = len(prompts)
        self.scheduler.set_timesteps(num_inference_steps)
        # NOTE(review): PIL resize takes (width, height); harmless while
        # Const.H == Const.W — confirm if they ever differ.
        image_pil = image_pil.resize((height, width), Image.LANCZOS)
        image_np = np.array(image_pil)[:, :, :3]
        source_latents = self._image2latent(image_np)
        init_image_tensor = read_image(
            image=image_pil, device=self.device, dest_size=(height, width)
        )
        # Number of denoising steps actually executed after the blending offset.
        total_steps = num_inference_steps - int(
            len(self.scheduler.timesteps) * blending_percentage
        )
        dyn_mask = DynMask(
            click_pil, self.args, init_image_tensor, self.device, total_steps
        )
        # Conditional (prompt) and unconditional text embeddings for CFG.
        text_input = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(self.device).half()
        original_rand_latents = latents
        # Const.DYN_* values > 1 are absolute step indices; <= 1 are fractions
        # of total_steps.
        dyn_start_step_i = (
            Const.DYN_START
            if Const.DYN_START > 1
            else round(Const.DYN_START * total_steps)
        )
        dyn_cond_stop_step_i = (
            Const.DYN_COND_STOP
            if Const.DYN_COND_STOP > 1
            else round(Const.DYN_COND_STOP * total_steps)
        )
        dyn_final_stop_step_i = (
            Const.DYN_FINAL_STOP
            if Const.DYN_FINAL_STOP > 1
            else round(Const.DYN_FINAL_STOP * total_steps)
        )
        # Evolve mask
        self.blended_latent_diffusion(
            dyn_mask=dyn_mask,
            create_dyn_mask=True,
            seed=seed,
            original_rand_latents=original_rand_latents,
            scheduler=self.scheduler,
            blending_percentage=blending_percentage,
            total_steps=total_steps,
            source_latents=source_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            dyn_start_step_i=dyn_start_step_i,
            dyn_cond_stop_step_i=dyn_cond_stop_step_i,
            dyn_final_stop_step_i=dyn_final_stop_step_i,
        )
        # Final run: render N_OUTS_FOR_DYN_MASK candidates with the frozen mask,
        # the first with the original seed/scheduler, the rest with fresh seeds
        # on a (typically shorter) static schedule.
        self.static_scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        self.static_scheduler.set_timesteps(num_static_inference_steps)
        total_static_steps = num_static_inference_steps - int(
            len(self.static_scheduler.timesteps) * blending_percentage
        )
        latents_list = []
        latent_masks_list = []
        seeds_list = []
        seeds_to_run = random.sample(range(1, Const.MAX_SEED), Const.N_OUTS_FOR_DYN_MASK - 1)
        print(f"running output (from {Const.N_OUTS_FOR_DYN_MASK}): ", end="")
        for out_i in range(Const.N_OUTS_FOR_DYN_MASK):
            print(f"{out_i + 1}", end="... ")
            orig_l = original_rand_latents
            seed_i = seed
            if out_i > 0:
                seed_i = seeds_to_run[out_i - 1]
                orig_l = torch.randn(
                    (batch_size, self.unet.config.in_channels, height // 8, width // 8),
                    generator=torch.manual_seed(seed_i),
                )
                orig_l = orig_l.to(self.device).half()
            latents, latent_mask = self.blended_latent_diffusion(
                dyn_mask=dyn_mask,
                create_dyn_mask=False,
                seed=seed_i,
                original_rand_latents=orig_l,
                scheduler=self.static_scheduler if out_i > 0 else self.scheduler,
                blending_percentage=blending_percentage,
                total_steps=total_static_steps if out_i > 0 else total_steps,
                source_latents=source_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale,
                max_area_ratio_for_dilation=Const.MAX_AREA_RATIO_FOR_DILATION,
                last_step_threshed_latent_mask=dyn_mask.get_curr_masks(
                    return_only=N.LATENT_MASK
                ),
            )
            latents_list.append(latents)
            latent_masks_list.append(latent_mask)
            seeds_list.append(seed_i)
        print("scoring...")
        results = self.score_and_arrange_results(
            dyn_mask=dyn_mask,
            latents_list=latents_list,
            latent_masks_list=latent_masks_list,
            n_runs=Const.N_RUNS_ON_SCORES,
            aug_num=Const.N_AUGS_ON_SCORES,
            alpha_mask_dilation_on_512=Const.ALPHA_MASK_DILATION_ON_512,
        )
        return results

    def _image2latent(self, image):
        """Encode an HWC uint8 numpy image into scaled SD latents (fp16)."""
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)
        image = image.half()
        latents = self.vae.encode(image)["latent_dist"].mean
        # 0.18215 is the SD-v1 VAE latent scaling factor.
        latents = latents * 0.18215
        return latents

    def back_preserve_with_gauss(self, decoded_img, latent_mask, dyn_mask):
        """Composite the decoded edit over the original image with a
        Gaussian-feathered upsampling of the latent mask (background preserved)."""
        upsampled_mask = latent_mask.cpu().numpy().squeeze()
        # BUGFIX: the resample flag (Image.LANCZOS) was previously passed as
        # cv2.resize's third POSITIONAL argument, which is `dst`, not
        # `interpolation` — so the intended LANCZOS resampling was silently
        # ignored. Pass the interpolation flag explicitly instead.
        upsampled_mask = cv2.resize(
            upsampled_mask.squeeze().astype(np.float32),
            dyn_mask.decoded_size,
            interpolation=cv2.INTER_LANCZOS4,
        )
        upsampled_mask = upsampled_mask > 0.5
        g_mask = gaussian_filter(
            upsampled_mask.astype(float), sigma=Const.BACK_PRES_SIGMA
        )
        g_mask = torch.from_numpy(g_mask).half().to(self.device)
        g_mask = (g_mask * Const.BACK_PRES_SCALE).clip(0, 1)
        # Keep the edited region fully opaque; feather only the surround.
        g_mask[upsampled_mask > 0.5] = 1
        blended = decoded_img * g_mask + dyn_mask.init_image * (1 - g_mask)
        return blended

    def score_and_arrange_results(
        self,
        dyn_mask,
        latents_list,
        latent_masks_list,
        n_runs,
        aug_num,
        alpha_mask_dilation_on_512,
    ):
        """Decode each candidate, blend it over the source, and attach an
        averaged alpha-CLIP similarity score under the "dist" key."""
        results = []
        raw_d_prompt = np.zeros((n_runs, len(latents_list)))
        for i, (latents, latent_mask) in enumerate(
            zip(latents_list, latent_masks_list)
        ):
            # Undo the VAE latent scaling before decoding.
            latents = 1 / 0.18215 * latents
            with torch.no_grad():
                img = self.vae.decode(latents).sample
            img = self.back_preserve_with_gauss(img, latent_mask, dyn_mask)
            results.append({"im": img, "latent_mask": latent_mask})
            alpha_mask = get_surround(
                latent_mask,
                alpha_mask_dilation_on_512 * (latent_mask.shape[-1] / 512.0),
                self.device,
            )
            if aug_num is not None:
                image_augmentations = ImageAugmentations(
                    self.args.alpha_clip_scale, aug_num
                )
            else:
                image_augmentations = None
            for run_i in range(n_runs):
                raw_d_prompt[run_i][i] = dyn_mask.alpha_clip_loss(
                    img,
                    alpha_mask,
                    dyn_mask.text_features,
                    image_augmentations=image_augmentations,
                    augs_with_orig=(run_i == 0),
                    return_as_similarity=True,
                )
        # Average the score over the n_runs repetitions.
        raw_d_prompt = raw_d_prompt.mean(axis=0)
        for i, res in enumerate(results):
            res["dist"] = float(raw_d_prompt[i])
        return results
def click2mask_app(prompt: str, image_pil: Image.Image, point512: np.ndarray):
    """Run Click2Mask end-to-end for an in-memory image and a click point.

    Evolves ``n_masks`` candidate masks, edits the image with each, and
    returns the highest-scoring edit as an HWC uint8 numpy array.
    """
    app = Click2Mask()
    app.args.prompt = prompt
    all_results = []
    for mask_idx in range(app.args.n_masks):
        print(f"\nEvolving mask {mask_idx + 1}...")
        # Honor the user-supplied seed for the first mask only; every
        # subsequent mask gets a fresh random seed.
        if app.args.seed and mask_idx == 0:
            seed = app.args.seed
        else:
            seed = random.sample(range(1, Const.MAX_SEED), 1)[0]
        seed_everything(seed)
        click_pil, _ = ClickDraw()(image_pil, point512=point512)
        all_results.extend(
            app.edit_image(
                image_pil=image_pil,
                click_pil=click_pil,
                prompts=[app.args.prompt] * Const.BATCH_SIZE,
                height=Const.H,
                width=Const.W,
                num_inference_steps=Const.NUM_INFERENCE_STEPS,
                num_static_inference_steps=Const.NUM_STATIC_INFERENCE_STEPS,
                guidance_scale=Const.GUIDANCE_SCALE,
                seed=seed,
                blending_percentage=Const.BLENDING_START_PERCENTAGE,
            )
        )
    # Best candidate = highest CLIP similarity ("dist").
    best = max(all_results, key=lambda r: r["dist"])
    out = (best["im"] / 2 + 0.5).clamp(0, 1)
    out = out.detach().cpu().permute(0, 2, 3, 1).numpy().squeeze()
    out = (out * 255).round().astype(np.uint8)
    torch.cuda.empty_cache()
    gc.collect()
    print("\nCompleted.")
    return out
if __name__ == "__main__":
    # CLI entry point: evolve n_masks masks for args.image_path and save the
    # best-scoring edit next to args.output_dir.
    c2m = Click2Mask()
    img_dir = os.path.dirname(c2m.args.image_path)
    img_name = os.path.basename(os.path.normpath(c2m.args.image_path))
    img_base_name = os.path.splitext(img_name)[0]
    results = []
    for mask_i in range(c2m.args.n_masks):
        print(f"\nEvolving mask {mask_i + 1}...")
        # User seed applies to the first mask only; later masks are random.
        seed = (
            c2m.args.seed
            if (c2m.args.seed and mask_i == 0)
            else random.sample(range(1, Const.MAX_SEED), 1)[0]
        )
        seed_everything(seed)
        # Look for an existing "<image>_click.<ext>" companion file.
        click_ext = [
            ext
            for ext in ("jpg", "JPG", "JPEG", "jpeg", "png", "PNG")
            if os.path.exists(os.path.join(img_dir, f"{img_base_name}_click.{ext}"))
        ]
        if (not click_ext) or (mask_i == 0 and c2m.args.refresh_click):
            # No click file found (or refresh requested): capture one interactively.
            click_create = ClickCreate()
            c2m.args.click_path = click_create(
                c2m.args.image_path, os.path.join(img_dir, f"{img_base_name}_click.jpg")
            )
        else:
            c2m.args.click_path = os.path.join(
                img_dir, f"{img_base_name}_click.{click_ext[0]}"
            )
        mask_i_results = c2m.edit_image(
            image_pil=Image.open(c2m.args.image_path),
            click_pil=Image.open(c2m.args.click_path),
            prompts=[c2m.args.prompt] * Const.BATCH_SIZE,
            height=Const.H,
            width=Const.W,
            num_inference_steps=Const.NUM_INFERENCE_STEPS,
            num_static_inference_steps=Const.NUM_STATIC_INFERENCE_STEPS,
            guidance_scale=Const.GUIDANCE_SCALE,
            seed=seed,
            blending_percentage=Const.BLENDING_START_PERCENTAGE,
        )
        results += mask_i_results
    os.makedirs(c2m.args.output_dir, exist_ok=True)
    # Pick the candidate with the highest CLIP similarity ("dist").
    sorted_results = sorted(results, key=lambda k: k["dist"], reverse=True)
    out_img = sorted_results[0]["im"]
    out_img = (out_img / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    out_img = out_img.detach().cpu().permute(0, 2, 3, 1).numpy().squeeze()
    out_img = (out_img * 255).round().astype(np.uint8)
    out_path = os.path.join(c2m.args.output_dir, f"{img_base_name}_out.jpg")
    Image.fromarray(out_img).save(out_path, quality=95)
    print(f"\nCompleted.\nOutput image path:\n{os.path.abspath(out_path)}")