| | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from PIL import Image, ImageFilter |
| |
|
| | from diffusers.image_processor import PipelineImageInput |
| | from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput |
| | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import ( |
| | StableDiffusionXLImg2ImgPipeline, |
| | rescale_noise_cfg, |
| | retrieve_latents, |
| | retrieve_timesteps, |
| | ) |
| | from diffusers.utils import ( |
| | deprecate, |
| | is_torch_xla_available, |
| | logging, |
| | ) |
| | from diffusers.utils.torch_utils import randn_tensor |
| |
|
| |
|
# Optional torch_xla support. The flag is consulted by the pipeline's denoising
# loop, which calls `xm.mark_step()` only when torch_xla could be imported.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

# Module-level logger obtained through diffusers' logging utilities.
logger = logging.get_logger(__name__)
| |
|
| |
|
class MaskedStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
    """SDXL img2img pipeline that repaints only a masked region of the input image.

    The region to repaint comes either from an explicit ``mask`` or, when no mask is
    given, from the pixels where ``image`` differs from ``original_image``. During
    denoising, latents outside the mask are replaced at every step by a freshly noised
    copy of the ``original_image`` latents. After decoding, the result is composited
    onto ``original_image`` through a Gaussian-blurred copy of the mask.
    """

    # When truthy, intermediate images are written as PNG files to the current
    # working directory ("non_paint_latents.png", "latent_mask.png", "stepNNN.png",
    # "from_latent.png"). Debugging aid only.
    debug_save = 0

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        original_image: PipelineImageInput = None,
        strength: float = 0.3,
        num_inference_steps: Optional[int] = 50,
        timesteps: List[int] = None,
        denoising_start: Optional[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: Optional[float] = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        aesthetic_score: float = 6.0,
        negative_aesthetic_score: float = 2.5,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        mask: Union[
            torch.FloatTensor,
            Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[Image.Image],
            List[np.ndarray],
        ] = None,
        blur=24,
        blur_compose=4,
        sample_mode="sample",
        **kwargs,
    ):
        r"""
        Run masked image-to-image generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            prompt_2 (`str` or `List[str]`, *optional*):
                Prompt(s) for the second text encoder; forwarded to `encode_prompt`.
            image (`PipelineImageInput`, *optional*):
                Image used as the starting point for generation. It may be a copy of `original_image` with the
                area to repaint painted over; when no `mask` is given, the differing pixels define the mask.
            original_image (`PipelineImageInput`, *optional*):
                Image the result is blended with outside the masked area. Defaults to `image`.
            strength (`float`, *optional*, defaults to 0.3):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps
                depends on the amount of noise initially added.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. This parameter is modulated by `strength`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps, forwarded to `retrieve_timesteps`.
            denoising_start (`float`, *optional*):
                Fraction (strictly between 0 and 1) of the schedule to skip; when set, no initial noise is added.
            denoising_end (`float`, *optional*):
                Fraction (strictly between 0 and 1) of the schedule after which denoising stops.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Classifier-free guidance scale. Guidance is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation.
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                Negative prompt(s) for the second text encoder.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `prepare_extra_step_kwargs` (used by schedulers that accept `eta`, e.g. DDIM).
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                Generator(s) used for every random draw, to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Accepted for signature compatibility; initial latents are always derived from `image`.
            prompt_embeds / negative_prompt_embeds / pooled_prompt_embeds / negative_pooled_prompt_embeds:
                Pre-generated text embeddings forwarded to `encode_prompt`.
            ip_adapter_image / ip_adapter_image_embeds (*optional*):
                IP-Adapter inputs forwarded to `prepare_ip_adapter_image_embeds`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format passed to the image processor, or `"latent"` to return raw latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `StableDiffusionXLPipelineOutput` instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                Forwarded to the UNet; its `"scale"` entry is also used as the text-encoder LoRA scale.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Rescale factor passed to `rescale_noise_cfg` when greater than 0.
            original_size / crops_coords_top_left / target_size / negative_* / aesthetic_score /
            negative_aesthetic_score:
                SDXL micro-conditioning inputs forwarded to `_get_add_time_ids`.
            clip_skip (`int`, *optional*):
                Forwarded to `encode_prompt`.
            callback_on_step_end (`Callable`, *optional*):
                Called after every step as `callback_on_step_end(self, i, t, callback_kwargs)`; may return
                replacement tensors for the names in `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
                Names of local tensors passed to `callback_on_step_end`.
            mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or a single-element list of these, *optional*):
                Non-zero elements mark the area to repaint. When given, it takes precedence over the mask derived
                from `image`/`original_image`. Arrays or tensors whose maximum is at most 1 are treated as being
                on a [0, 1] scale. When neither `mask` nor `original_image` is given, the whole image is repainted.
            blur (`int`, *optional*, defaults to 24):
                Currently unused; kept for backward compatibility.
            blur_compose (`int`, *optional*, defaults to 4):
                Gaussian blur radius applied to the mask used to composite the decoded image onto
                `original_image`.
            sample_mode (`str`, *optional*, defaults to `"sample"`):
                How the initial latents of `image` are obtained: `"sample"` or `"argmax"` (forwarded to
                `retrieve_latents`), or `"random"` for random latents.
            callback / callback_steps (passed through `**kwargs`, deprecated):
                Legacy callback `callback(step_idx, t, latents)`, called every `callback_steps` steps
                (every step when `callback_steps` is not given).

        Returns:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, a `StableDiffusionXLPipelineOutput` is returned, otherwise a tuple whose
                only element is the generated images (or latents when `output_type == "latent"`).

        Raises:
            ValueError: If neither `image` nor `original_image` is given, if `image` is missing while `mask` is
                also missing, or if more than one mask is passed.
        """
        # Code adapted from the parent `StableDiffusionXLImg2ImgPipeline.__call__`.
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # 0. Check inputs. Raise error if not correct.
        self.check_inputs(
            prompt,
            prompt_2,
            strength,
            num_inference_steps,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )
        # Without an explicit `callback_steps`, the legacy callback fires on every step
        # (previously `i % None` raised a TypeError).
        callback_every = 1 if callback_steps is None else callback_steps

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start
        self._interrupt = False

        # 1. Resolve the inpainting mask and the image to composite onto.
        if image is None and original_image is None:
            raise ValueError("Either `image` or `original_image` must be provided.")
        if mask is None:
            if image is None:
                raise ValueError("`mask` must be provided when `image` is not given.")
            if original_image is None:
                # Nothing to compare against: repaint the whole image.
                mask = np.full(np.array(image).shape[:2], 255, dtype=np.uint8)
            else:
                # Pixels where the painted `image` differs from `original_image` form the mask.
                neq = np.any(np.array(original_image) != np.array(image), axis=-1)
                mask = neq.astype(np.uint8) * 255
        mask = self._normalize_mask(mask)
        # Works for both PIL and array masks (a PIL mask previously left `pil_mask` unbound).
        pil_mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
        if pil_mask.mode != "L":
            pil_mask = pil_mask.convert("L")
        # Softened mask used to composite the decoded image onto `original_image`.
        mask_compose = self.blur_mask(pil_mask, blur_compose)
        if original_image is None:
            original_image = image

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 2. Encode input prompt.
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

        # 3. Preprocess images.
        input_image = image if image is not None else original_image
        image = self.image_processor.preprocess(input_image)
        original_image = self.image_processor.preprocess(original_image)

        # 4. Set timesteps.
        def denoising_value_valid(dnv):
            return isinstance(dnv, float) and 0 < dnv < 1

        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps,
            strength,
            device,
            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
        )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        add_noise = self.denoising_start is None

        # 5. Prepare latents: noised latents of `image` (the area to repaint) and
        # noise-free latents of `original_image` (the area to keep).
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
            add_noise,
            sample_mode=sample_mode,
        )
        non_paint_latents = self.prepare_latents(
            original_image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
            add_noise=False,
            sample_mode="argmax",
        )

        if self.debug_save:
            init_img_from_latents = self.latents_to_img(non_paint_latents)
            init_img_from_latents[0].save("non_paint_latents.png")

        # 6. Downsample the mask to the latent resolution.
        latent_mask = self._make_latent_mask(latents, mask)

        # 7. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        height, width = latents.shape[-2:]
        height = height * self.vae_scale_factor
        width = width * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 8. Prepare added time ids & embeddings.
        if negative_original_size is None:
            negative_original_size = original_size
        if negative_target_size is None:
            negative_target_size = target_size

        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device)

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 9. Denoising loop setup.
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        if (
            self.denoising_end is not None
            and self.denoising_start is not None
            and denoising_value_valid(self.denoising_end)
            and denoising_value_valid(self.denoising_start)
            and self.denoising_start >= self.denoising_end
        ):
            raise ValueError(
                f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
                + f" {self.denoising_end} when using type float."
            )
        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 10. Denoising loop.
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Re-anchor the unmasked area to the original image, noised to the current timestep.
                noise = randn_tensor(
                    non_paint_latents.shape, generator=generator, device=device, dtype=latents.dtype
                )
                orig_latents_t = self.scheduler.add_noise(non_paint_latents, noise, t.unsqueeze(0))
                # Equivalent to orig_latents_t * (1 - latent_mask) + latents * latent_mask.
                latents = torch.lerp(orig_latents_t, latents, latent_mask)

                if self.debug_save:
                    img1 = self.latents_to_img(latents)
                    # Zero-pad to 3 digits without rebinding the loop index `i`
                    # (the old padding loop reused `i`, corrupting callbacks and the progress bar).
                    t_str = str(t.int().item()).zfill(3)
                    img1[0].save(f"step{t_str}.png")

                # Expand the latents if we are doing classifier-free guidance.
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Predict the noise residual.
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # Perform guidance.
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # Compute the previous noisy sample x_t -> x_t-1.
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    # Only restore the dtype on MPS backends.
                    if torch.backends.mps.is_available():
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )
                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_every == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

        # 11. Decode, composite onto the original image and post-process.
        if not output_type == "latent":
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
            elif latents.dtype != self.vae.dtype:
                # Only cast the VAE on MPS backends.
                if torch.backends.mps.is_available():
                    self.vae = self.vae.to(latents.dtype)

            if self.debug_save:
                image_gen = self.latents_to_img(latents)
                image_gen[0].save("from_latent.png")

            if latent_mask is not None:
                # Outside the mask, fall back to the noise-free latents of `original_image`.
                latents = torch.lerp(non_paint_latents, latents, latent_mask)

            latents = self.denormalize(latents)
            if needs_upcasting:
                # Feed the upcast VAE latents in the dtype of its `post_quant_conv` weights
                # (avoids a dtype mismatch between fp16 latents and fp32 weights).
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            image = self.vae.decode(latents, return_dict=False)[0]
            # Composite the decoded image onto `original_image` through the softened mask.
            m = mask_compose.permute(2, 0, 1).unsqueeze(0).to(image)
            image = m * image + (1 - m) * original_image.to(image)

            # Cast back to fp16 if needed.
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

            # Watermarking and post-processing apply to decoded images only, never to raw latents.
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)
        else:
            image = latents

        # Offload all models.
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

    @staticmethod
    def _normalize_mask(mask):
        """Coerce a single user/derived mask to a PIL image or an 8-bit numpy array.

        Single-element lists are unwrapped; torch tensors are moved to CPU and converted
        to numpy; a trailing channel axis of size 1 is dropped. Numpy masks whose maximum
        is at most 1 are treated as being on a [0, 1] scale (the same convention
        `_make_latent_mask` uses) and rescaled to [0, 255].

        Raises:
            ValueError: If a list with more than one mask is given.
        """
        if isinstance(mask, list):
            if len(mask) != 1:
                raise ValueError(f"Only a single mask is supported, but {len(mask)} masks were passed.")
            mask = mask[0]
        if isinstance(mask, Image.Image):
            return mask
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        mask = np.asarray(mask)
        if mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[..., 0]
        values = mask.astype(np.float32)
        if values.size and values.max() <= 1:
            values = values * 255.0
        return np.clip(values, 0, 255).round().astype(np.uint8)

    def _make_latent_mask(self, latents, mask):
        """Downsample `mask` to the spatial size of `latents`.

        Args:
            latents: Tensor whose ``(batch, channels, height, width)`` shape sets the target size.
            mask: A single mask or a list of masks (PIL images, numpy arrays or torch tensors).

        Returns:
            ``None`` if `mask` is ``None``; otherwise a tensor of shape
            ``(num_masks, channels, height, width)`` with the dtype/device of `latents` and
            values in ``[0, 1]``.
        """
        if mask is None:
            return None
        masks = mask if isinstance(mask, list) else [mask]
        _, l_channels, l_height, l_width = latents.shape
        per_mask = []
        for m in masks:
            if isinstance(m, torch.Tensor):
                m = m.detach().cpu().numpy()
            if not isinstance(m, Image.Image):
                if len(m.shape) == 2:
                    m = m[..., np.newaxis]
                # Masks on a [0, 255] scale are brought to [0, 1] before `numpy_to_pil`.
                if m.max() > 1:
                    m = m / 255.0
                m = self.image_processor.numpy_to_pil(m)[0]
            if m.mode != "L":
                m = m.convert("L")
            resized = self.image_processor.resize(m, l_height, l_width)
            if self.debug_save:
                resized.save("latent_mask.png")
            # Repeat the single-channel mask across all latent channels.
            per_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))
        latent_mask = torch.as_tensor(np.stack(per_mask)).to(latents)
        return latent_mask / max(latent_mask.max(), 1)

    def prepare_latents(
        self,
        image,
        timestep,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        generator=None,
        add_noise=True,
        sample_mode: str = "sample",
    ):
        """Turn a preprocessed image batch into scaled latents, optionally noised to `timestep`.

        Args:
            image: Preprocessed image tensor, or a tensor that already holds 4-channel latents.
            timestep: Timestep(s) passed to `scheduler.add_noise` when `add_noise` is true.
            batch_size: Prompt batch size; multiplied by `num_images_per_prompt`.
            num_images_per_prompt: Number of images generated per prompt.
            dtype: Dtype of the returned latents.
            device: Device of the returned latents.
            generator: Generator or list of generators (one per effective batch item).
            add_noise: Whether to noise the latents to `timestep`.
            sample_mode: ``"sample"``/``"argmax"`` (forwarded to `retrieve_latents`) or ``"random"``.

        Returns:
            Latents of effective batch size ``batch_size * num_images_per_prompt``.

        Raises:
            ValueError: For an unsupported `image` type, a generator list of the wrong length, or a
                batch size that is not a multiple of the number of images.
        """
        if not isinstance(image, (torch.Tensor, Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        # Offload the second text encoder when a final offload hook is installed.
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.text_encoder_2.to("cpu")
            torch.cuda.empty_cache()

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            # The input already holds latents.
            init_latents = image
        elif sample_mode == "random":
            height, width = image.shape[-2:]
            num_channels_latents = self.unet.config.in_channels
            latents = self.random_latents(
                batch_size,
                num_channels_latents,
                height,
                width,
                dtype,
                device,
                generator,
            )
            # NOTE(review): random latents skip the timestep noise and are multiplied by the
            # VAE scaling factor here; confirm this scaling is intended.
            return self.vae.config.scaling_factor * latents
        else:
            # Validate before touching the VAE so an error cannot leave it upcast to float32.
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            # Encode in float32 when the VAE config requests it.
            if self.vae.config.force_upcast:
                image = image.float()
                self.vae.to(dtype=torch.float32)

            if isinstance(generator, list):
                init_latents = [
                    retrieve_latents(
                        self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode
                    )
                    for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode=sample_mode)

            if self.vae.config.force_upcast:
                self.vae.to(dtype)

            init_latents = init_latents.to(dtype)
            # Inverse of `denormalize`, so encode/decode round-trips for every VAE config.
            init_latents = self._normalize_latents(init_latents)

        if batch_size > init_latents.shape[0]:
            if batch_size % init_latents.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
                )
            # Repeat the image latents to cover every image of every prompt.
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)

        if add_noise:
            noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        return init_latents

    def random_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Return `latents` (or fresh Gaussian noise) at latent resolution, scaled by `init_noise_sigma`.

        Raises:
            ValueError: If a generator list's length differs from `batch_size`.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _latents_mean_std(self, latents):
        """Return the VAE's ``(latents_mean, latents_std)`` as broadcastable tensors, or ``(None, None)``."""
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if not (has_latents_mean and has_latents_std):
            return None, None
        latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1).to(latents.device, latents.dtype)
        latents_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1).to(latents.device, latents.dtype)
        return latents_mean, latents_std

    def _normalize_latents(self, latents):
        """Map raw VAE-encoder latents into the scaled latent space (exact inverse of `denormalize`)."""
        latents_mean, latents_std = self._latents_mean_std(latents)
        if latents_mean is not None:
            return (latents - latents_mean) * self.vae.config.scaling_factor / latents_std
        return self.vae.config.scaling_factor * latents

    def denormalize(self, latents):
        """Map scaled latents back to the space expected by `vae.decode`.

        Uses the VAE's `latents_mean`/`latents_std` when both are configured, otherwise
        only divides by `scaling_factor`.
        """
        latents_mean, latents_std = self._latents_mean_std(latents)
        if latents_mean is not None:
            return latents * latents_std / self.vae.config.scaling_factor + latents_mean
        return latents / self.vae.config.scaling_factor

    def latents_to_img(self, latents):
        """Decode scaled latents into a list of PIL images (used by the debug output)."""
        l1 = self.denormalize(latents)
        img1 = self.vae.decode(l1, return_dict=False)[0]
        # One flag per batch item; a single-element list fails for batches larger than one.
        return self.image_processor.postprocess(img1, output_type="pil", do_denormalize=[True] * img1.shape[0])

    def blur_mask(self, pil_mask, blur):
        """Gaussian-blur a single-channel PIL mask and return it as an ``(H, W, 3)`` tensor.

        The blurred mask is divided by its maximum so the peak is 1; an all-zero mask is
        returned as zeros instead of producing NaNs from a division by zero.
        """
        mask_blur = np.array(pil_mask.filter(ImageFilter.GaussianBlur(radius=blur)))
        peak = mask_blur.max()
        normalized = mask_blur / (peak if peak > 0 else 1)
        return torch.from_numpy(np.tile(normalized, (3, 1, 1)).transpose(1, 2, 0))