| import pdb, sys |
|
|
| import numpy as np |
| import torch |
| from typing import Any, Callable, Dict, List, Optional, Union |
| from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
| sys.path.insert(0, "src/utils") |
| from base_pipeline import BasePipeline |
| from cross_attention import prep_unet |
|
|
|
|
class EditingPipeline(BasePipeline):
    """Image editing by cross-attention-guided denoising.

    ``__call__`` runs two denoising passes from the same initial latents
    (prepared from ``x_in``):

    1. A *reference* pass with the original prompt embeddings. After every
       UNet forward, the ``attn_probs`` of each ``CrossAttention`` module
       whose name contains ``attn2`` are recorded per timestep.
    2. An *edit* pass with the conditional prompt embeddings shifted by
       ``edit_dir``. At each timestep the model input first takes one
       gradient step (size ``guidance_amount``) that lowers the squared
       difference between the current and recorded attention maps, and the
       denoising step is then taken from the updated input.
    """

    def _iter_cross_attn_modules(self):
        """Yield ``(name, module)`` for every UNet ``CrossAttention`` module named ``*attn2*``."""
        for name, module in self.unet.named_modules():
            if type(module).__name__ == "CrossAttention" and "attn2" in name:
                yield name, module

    @staticmethod
    def _apply_cfg(noise_pred, guidance_scale, do_classifier_free_guidance):
        """Combine the [unconditional; text] halves of ``noise_pred`` with classifier-free guidance."""
        if not do_classifier_free_guidance:
            return noise_pred
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_amount=0.1,
        edit_dir=None,
        x_in=None,
    ):
        """Reconstruct and edit an image from the starting latents ``x_in``.

        Args:
            prompt, negative_prompt, prompt_embeds, negative_prompt_embeds,
            num_images_per_prompt: forwarded to ``self._encode_prompt``.
                ``prompt`` (or ``prompt_embeds`` when ``prompt`` is None)
                also determines the batch size.
            height, width: latent size passed to ``prepare_latents``; default
                to ``unet.config.sample_size * vae_scale_factor``.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance is used when > 1.
            eta, generator: forwarded to ``prepare_extra_step_kwargs``;
                ``generator`` is also forwarded to ``prepare_latents``.
            latents: accepted for signature compatibility but ignored; the
                initial latents are always prepared from ``x_in``.
            cross_attention_kwargs: forwarded to every UNet call.
            guidance_amount: step size of the per-timestep gradient step on
                the model input during the edit pass.
            edit_dir: added to the text-conditioned prompt embeddings for the
                edit pass (presumably a direction in embedding space — TODO
                confirm expected shape against callers).
            x_in: starting latents for both passes (presumably inverted
                latents of the input image — confirm).

        Returns:
            ``(image_rec, image_edit)``: the reference reconstruction and the
            edited result, each converted by ``numpy_to_pil``. Only the edited
            result is passed through ``run_safety_checker``.

        Raises:
            ValueError: if ``x_in`` or ``edit_dir`` is missing.
            RuntimeError: if the prepared UNet has no ``attn2``
                ``CrossAttention`` modules to guide with.
        """
        if x_in is None:
            raise ValueError("`x_in` (starting latents for both passes) is required.")
        if edit_dir is None:
            raise ValueError("`edit_dir` (prompt-embedding edit direction) is required.")

        # NOTE(review): prep_unet is a project helper; presumably it patches
        # the CrossAttention modules to keep `attn_probs` after each forward.
        self.unet = prep_unet(self.unet)
        if next(self._iter_cross_attn_modules(), None) is None:
            raise RuntimeError(
                "No 'attn2' CrossAttention modules found in the UNet; "
                "cross-attention guidance is impossible."
            )

        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        # Tensor.to is not in-place: the result must be rebound.
        x_in = x_in.to(dtype=self.unet.dtype, device=device)

        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            x_in,
        )
        latents_init = latents.clone()

        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        def _should_update(step_index):
            return step_index == len(timesteps) - 1 or (
                (step_index + 1) > num_warmup_steps
                and (step_index + 1) % self.scheduler.order == 0
            )

        # ---- Pass 1: reference reconstruction, recording cross-attention maps.
        d_ref_t2attn = {}
        with torch.no_grad():
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample

                    # Reference maps are kept on the CPU for all timesteps.
                    d_ref_t2attn[t.item()] = {
                        name: module.attn_probs.detach().cpu()
                        for name, module in self._iter_cross_attn_modules()
                    }

                    noise_pred = self._apply_cfg(noise_pred, guidance_scale, do_classifier_free_guidance)
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                    if _should_update(i):
                        progress_bar.update()

        image_rec = self.numpy_to_pil(self.decode_latents(latents.detach()))

        # ---- Pass 2: edited generation, guided toward the reference maps.
        prompt_embeds_edit = prompt_embeds.clone()
        edit_dir = torch.as_tensor(edit_dir, device=device)
        if do_classifier_free_guidance:
            # Embeddings are ordered [unconditional; text] (matching the
            # chunk(2) order in _apply_cfg); shift only the text half.
            prompt_embeds_edit[prompt_embeds_edit.shape[0] // 2:] += edit_dir
        else:
            prompt_embeds_edit += edit_dir

        latents = latents_init
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                ref_maps = d_ref_t2attn[t.item()]
                # enable_grad: the guidance step must work even if the caller
                # invoked the pipeline under torch.no_grad().
                with torch.enable_grad():
                    x_in = latent_model_input.detach().clone()
                    x_in.requires_grad_(True)

                    # Forward pass only to refresh each module's attn_probs.
                    self.unet(
                        x_in,
                        t,
                        encoder_hidden_states=prompt_embeds_edit.detach(),
                        cross_attention_kwargs=cross_attention_kwargs,
                    )

                    loss = 0.0
                    for name, module in self._iter_cross_attn_modules():
                        ref = ref_maps[name].detach().to(device)
                        loss += ((module.attn_probs - ref) ** 2).sum((1, 2)).mean(0)

                    # autograd.grad returns only d(loss)/d(x_in); unlike
                    # loss.backward() it does not accumulate .grad on the UNet
                    # parameters at every timestep.
                    (grad,) = torch.autograd.grad(loss, x_in)

                # Same update as one plain SGD step with lr=guidance_amount.
                x_in = x_in.detach().add(grad, alpha=-guidance_amount)

                with torch.no_grad():
                    noise_pred = self.unet(
                        x_in,
                        t,
                        encoder_hidden_states=prompt_embeds_edit,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample

                # The updated first (unconditional) half is carried forward as
                # the latents. NOTE(review): x_in is the *scaled* model input,
                # which equals the latents only when scale_model_input is the
                # identity (e.g. DDIM) — confirm for other schedulers.
                latents = x_in.chunk(2)[0] if do_classifier_free_guidance else x_in

                noise_pred = self._apply_cfg(noise_pred, guidance_scale, do_classifier_free_guidance)
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if _should_update(i):
                    progress_bar.update()

        image = self.decode_latents(latents.detach())
        image, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
        image_edit = self.numpy_to_pil(image)

        return image_rec, image_edit
|
|