| | """ |
| | modeled after the textual_inversion.py / train_dreambooth.py and the work |
| | of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb |
| | """ |
| |
|
import inspect
import warnings
from typing import List, Optional, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }


logger = logging.get_logger(__name__)
|
| |
|
| | def preprocess(image): |
| | w, h = image.size |
| | w, h = (x - x % 32 for x in (w, h)) |
| | image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) |
| | image = np.array(image).astype(np.float32) / 255.0 |
| | image = image[None].transpose(0, 3, 1, 2) |
| | image = torch.from_numpy(image) |
| | return 2.0 * image - 1.0 |


class ImagicStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
    r"""
    Pipeline for Imagic image editing.
    See the paper here: https://arxiv.org/pdf/2210.09276.pdf

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
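
    Example (a minimal usage sketch; the checkpoint id, scheduler settings, prompt, and file paths below are
    illustrative, not prescriptive):

        ```py
        import PIL.Image
        from diffusers import DDIMScheduler, DiffusionPipeline

        pipe = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="imagic_stable_diffusion",
            scheduler=DDIMScheduler(
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
            ),
        ).to("cuda")

        init_image = PIL.Image.open("./input.png").convert("RGB").resize((512, 512))
        pipe.train("A photo of a bird spreading its wings.", image=init_image)
        res = pipe(alpha=1.2, guidance_scale=7.5, num_inference_steps=50)
        res.images[0].save("./imagic_edit.png")
        ```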
| | """ |

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def train(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        generator: Optional[torch.Generator] = None,
        embedding_learning_rate: float = 0.001,
        diffusion_model_learning_rate: float = 2e-6,
        text_embedding_optimization_steps: int = 500,
        model_fine_tuning_optimization_steps: int = 1000,
        **kwargs,
    ):
| | r""" |
| | Function invoked when calling the pipeline for generation. |
| | Args: |
| | prompt (`str` or `List[str]`): |
| | The prompt or prompts to guide the image generation. |
| | height (`int`, *optional*, defaults to 512): |
| | The height in pixels of the generated image. |
| | width (`int`, *optional*, defaults to 512): |
| | The width in pixels of the generated image. |
| | num_inference_steps (`int`, *optional*, defaults to 50): |
| | The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| | expense of slower inference. |
| | guidance_scale (`float`, *optional*, defaults to 7.5): |
| | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
| | `guidance_scale` is defined as `w` of equation 2. of [Imagen |
| | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
| | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
| | usually at the expense of lower image quality. |
| | eta (`float`, *optional*, defaults to 0.0): |
| | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to |
| | [`schedulers.DDIMScheduler`], will be ignored for others. |
| | generator (`torch.Generator`, *optional*): |
| | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation |
| | deterministic. |
| | latents (`torch.Tensor`, *optional*): |
| | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image |
| | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
| | tensor will ge generated by sampling using the supplied random `generator`. |
| | output_type (`str`, *optional*, defaults to `"pil"`): |
| | The output format of the generate image. Choose between |
| | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. |
| | return_dict (`bool`, *optional*, defaults to `True`): |
| | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
| | plain tuple. |
| | Returns: |
| | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
| | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. |
| | When returning a tuple, the first element is a list with the generated images, and the second element is a |
| | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" |
| | (nsfw) content, according to the `safety_checker`. |
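
        Example (an illustrative sketch; the image path, prompt, and step counts are placeholders):

            ```py
            init_image = PIL.Image.open("./dog.png").convert("RGB").resize((512, 512))
            pipe.train(
                "A photo of a sitting dog.",
                image=init_image,
                embedding_learning_rate=1e-3,
                text_embedding_optimization_steps=500,
                model_fine_tuning_optimization_steps=1000,
            )
            ```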
| | """ |
        accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision="fp16",
        )

        if "torch_device" in kwargs:
            device = kwargs.pop("torch_device")
            warnings.warn(
                "`torch_device` is deprecated as an input argument to `train` and will be removed in v0.3.0."
                " Consider using `pipe.to(torch_device)` instead."
            )

            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Freeze vae, unet and text_encoder: at this stage only the text embedding is optimized.
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.eval()
        self.vae.eval()
        self.text_encoder.eval()

        if accelerator.is_main_process:
            accelerator.init_trackers(
                "imagic",
                config={
                    "embedding_learning_rate": embedding_learning_rate,
                    "text_embedding_optimization_steps": text_embedding_optimization_steps,
                },
            )

        # Get text embeddings for the prompt and make them a trainable leaf tensor.
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = torch.nn.Parameter(
            self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
        )
        text_embeddings = text_embeddings.detach()
        text_embeddings.requires_grad_()
        text_embeddings_orig = text_embeddings.clone()

        # Initialize the optimizer over the text embedding only.
        optimizer = torch.optim.Adam(
            [text_embeddings],
            lr=embedding_learning_rate,
        )

        if isinstance(image, PIL.Image.Image):
            image = preprocess(image)

        # Encode the init image into latents and scale the latents.
        latents_dtype = text_embeddings.dtype
        image = image.to(device=self.device, dtype=latents_dtype)
        init_latent_image_dist = self.vae.encode(image).latent_dist
        image_latents = init_latent_image_dist.sample(generator=generator)
        image_latents = 0.18215 * image_latents
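        # (0.18215 is the latent scaling factor of the Stable Diffusion v1 VAE; newer diffusers
        # releases expose it as `self.vae.config.scaling_factor`.)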

        progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
        progress_bar.set_description("Steps")

        global_step = 0

        logger.info("First optimizing the text embedding to better reconstruct the init image")
        for _ in range(text_embedding_optimization_steps):
            with accelerator.accumulate(text_embeddings):
                # Sample noise that we'll add to the latents, and a random timestep
                noise = torch.randn(image_latents.shape).to(image_latents.device)
                timesteps = torch.randint(1000, (1,), device=image_latents.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

            # Check if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()

        text_embeddings.requires_grad_(False)

        # Now fine-tune the unet to better reconstruct the init image from the optimized embedding.
        self.unet.requires_grad_(True)
        self.unet.train()
        optimizer = torch.optim.Adam(
            self.unet.parameters(),
            lr=diffusion_model_learning_rate,
        )
        progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)

        logger.info("Next fine tuning the entire model to better reconstruct the init image")
        for _ in range(model_fine_tuning_optimization_steps):
            with accelerator.accumulate(self.unet):
                # Sample noise that we'll add to the latents, and a random timestep
                noise = torch.randn(image_latents.shape).to(image_latents.device)
                timesteps = torch.randint(1000, (1,), device=image_latents.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

            # Check if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()
        self.text_embeddings_orig = text_embeddings_orig
        self.text_embeddings = text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        alpha: float = 1.2,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
    ):
| | r""" |
| | Function invoked when calling the pipeline for generation. |
| | Args: |
| | alpha (`float`, *optional*, defaults to 1.2): |
| | The interpolation factor between the original and optimized text embeddings. A value closer to 0 |
| | will resemble the original input image. |
| | height (`int`, *optional*, defaults to 512): |
| | The height in pixels of the generated image. |
| | width (`int`, *optional*, defaults to 512): |
| | The width in pixels of the generated image. |
| | num_inference_steps (`int`, *optional*, defaults to 50): |
| | The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| | expense of slower inference. |
| | guidance_scale (`float`, *optional*, defaults to 7.5): |
| | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
| | `guidance_scale` is defined as `w` of equation 2. of [Imagen |
| | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
| | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
| | usually at the expense of lower image quality. |
| | generator (`torch.Generator`, *optional*): |
| | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation |
| | deterministic. |
| | output_type (`str`, *optional*, defaults to `"pil"`): |
| | The output format of the generate image. Choose between |
| | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. |
| | return_dict (`bool`, *optional*, defaults to `True`): |
| | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
| | plain tuple. |
| | eta (`float`, *optional*, defaults to 0.0): |
| | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to |
| | [`schedulers.DDIMScheduler`], will be ignored for others. |
| | Returns: |
| | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
| | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. |
| | When returning a tuple, the first element is a list with the generated images, and the second element is a |
| | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" |
| | (nsfw) content, according to the `safety_checker`. |
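
        Example (an illustrative sketch; assumes `pipe.train(...)` has already been run on an input image):

            ```py
            res = pipe(alpha=1.2, num_inference_steps=50, guidance_scale=7.5)
            image = res.images[0]  # a PIL.Image.Image when output_type="pil"
            image.save("./edited.png")
            ```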
| | """ |
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        if getattr(self, "text_embeddings", None) is None:
            raise ValueError("Please run `pipe.train()` before trying to generate an image.")
        if getattr(self, "text_embeddings_orig", None) is None:
            raise ValueError("Please run `pipe.train()` before trying to generate an image.")

        # Interpolate between the original and the optimized text embedding. Values of `alpha`
        # closer to 0 stay near the optimized embedding (reconstructing the input image), while
        # larger values push the result toward the target prompt.
        text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings

        # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) of the
        # Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` corresponds to doing
        # no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # Get unconditional embeddings for classifier-free guidance
        if do_classifier_free_guidance:
            uncond_tokens = [""]
            max_length = self.tokenizer.model_max_length
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # Duplicate unconditional embeddings for each generation per prompt, using an mps-friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)

            # For classifier-free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Get the initial random noise. Latents cannot be generated directly on mps, so for that
        # device they are sampled on CPU first and then moved over.
        latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if self.device.type == "mps":
            latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                self.device
            )
        else:
            latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers (like PNDM) have timesteps as arrays.
        # It's more optimized to move all timesteps to the correct device beforehand.
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # Scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler and is ignored by other schedulers.
        # It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be in [0, 1].
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for t in self.progress_bar(timesteps_tensor):
            # Expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # Perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # Scale back and decode the latents with the VAE
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # We always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # Run the safety checker (if one is loaded) on the decoded image
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)