| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import inspect |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import PIL.Image |
| import torch |
| import torch.nn.functional as F |
| import torchsde |
|
|
| from transformers import ( |
| CLIPImageProcessor, |
| CLIPTextModel, |
| CLIPTextModelWithProjection, |
| CLIPTokenizer, |
| CLIPVisionModelWithProjection, |
| ) |
|
|
| from diffusers.utils.import_utils import is_invisible_watermark_available |
|
|
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor |
| from diffusers.loaders import ( |
| FromSingleFileMixin, |
| IPAdapterMixin, |
| StableDiffusionXLLoraLoaderMixin, |
| TextualInversionLoaderMixin, |
| ) |
| from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel |
| from diffusers.models.attention_processor import ( |
| AttnProcessor2_0, |
| LoRAAttnProcessor2_0, |
| LoRAXFormersAttnProcessor, |
| XFormersAttnProcessor, |
| ) |
| from diffusers.models.lora import adjust_lora_scale_text_encoder |
| from diffusers.schedulers import KarrasDiffusionSchedulers |
| from diffusers.utils import ( |
| USE_PEFT_BACKEND, |
| deprecate, |
| logging, |
| replace_example_docstring, |
| scale_lora_layers, |
| unscale_lora_layers, |
| ) |
| from diffusers.utils.torch_utils import is_compiled_module, randn_tensor |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin |
| from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput |
|
|
|
|
| if is_invisible_watermark_available(): |
| from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker |
|
|
| from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel |
|
|
|
|
# Module-level logger obtained from diffusers' logging utilities.
logger = logging.get_logger(__name__)


# Usage example text for the pipeline documentation.
# NOTE(review): the example drives `StableDiffusionXLControlNetImg2ImgPipeline`, not `OmniZeroPipeline`,
# and nothing in the visible part of this file references this constant -- confirm it is still wanted.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> # pip install accelerate transformers safetensors diffusers

        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image

        >>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation
        >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
        >>> from diffusers.utils import load_image


        >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
        >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
        >>> controlnet = ControlNetModel.from_pretrained(
        ...     "diffusers/controlnet-depth-sdxl-1.0-small",
        ...     variant="fp16",
        ...     use_safetensors=True,
        ...     torch_dtype=torch.float16,
        ... ).to("cuda")
        >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
        >>> pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-xl-base-1.0",
        ...     controlnet=controlnet,
        ...     vae=vae,
        ...     variant="fp16",
        ...     use_safetensors=True,
        ...     torch_dtype=torch.float16,
        ... ).to("cuda")
        >>> pipe.enable_model_cpu_offload()


        >>> def get_depth_map(image):
        ...     image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        ...     with torch.no_grad(), torch.autocast("cuda"):
        ...         depth_map = depth_estimator(image).predicted_depth

        ...     depth_map = torch.nn.functional.interpolate(
        ...         depth_map.unsqueeze(1),
        ...         size=(1024, 1024),
        ...         mode="bicubic",
        ...         align_corners=False,
        ...     )
        ...     depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        ...     depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        ...     depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        ...     image = torch.cat([depth_map] * 3, dim=1)
        ...     image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        ...     image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        ...     return image


        >>> prompt = "A robot, 4k photo"
        >>> image = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/cat.png"
        ... ).resize((1024, 1024))
        >>> controlnet_conditioning_scale = 0.5  # recommended for good generalization
        >>> depth_image = get_depth_map(image)

        >>> images = pipe(
        ...     prompt,
        ...     image=image,
        ...     control_image=depth_image,
        ...     strength=0.99,
        ...     num_inference_steps=50,
        ...     controlnet_conditioning_scale=controlnet_conditioning_scale,
        ... ).images
        >>> images[0].save(f"robot_cat.png")
        ```
"""
|
|
|
|
| |
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from a VAE encoder output.

    Args:
        encoder_output: Object returned by a VAE ``encode`` call. It is expected to expose either
            ``latent_dist`` (an object with ``sample(generator)`` / ``mode()``) or a ready-made
            ``latents`` attribute.
        generator: Forwarded to ``latent_dist.sample`` when ``sample_mode == "sample"``.
        sample_mode: ``"sample"`` draws from ``latent_dist``; ``"argmax"`` takes its mode. Any other
            value only succeeds when ``encoder_output.latents`` exists.

    Returns:
        The latents taken from ``encoder_output``.

    Raises:
        AttributeError: If neither a usable ``latent_dist`` nor ``latents`` is available.
    """
    has_latent_dist = hasattr(encoder_output, "latent_dist")
    if has_latent_dist and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_latent_dist and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
|
|
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    One ``torchsde.BrownianTree`` is built per seed. When ``seed`` is a sequence (one seed per
    entry of ``x``'s first dimension) the per-tree results are stacked along a new leading
    dimension; when it is a single value (or ``None``) one tree is used and that dimension is
    dropped again.

    Args:
        x: Template tensor; the initial value is ``torch.zeros_like(x)`` unless ``w0`` is given.
        t0, t1: Interval endpoints; they may be given in either order.
        seed: ``None`` (draw a random seed from the global torch RNG), a scalar seed, or a
            sequence with one seed per batch item.
        **kwargs: ``cpu`` (default ``True``) builds/queries the trees on the CPU; ``w0``
            overrides the initial value; everything else is forwarded to
            ``torchsde.BrownianTree``.

    Raises:
        ValueError: If a seed sequence's length differs from ``x.shape[0]``.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        # ``cpu`` is our own option and is never forwarded to torchsde.
        self.cpu_tree = kwargs.pop("cpu", True)
        t0, t1, self.sign = self.sort(t0, t1)
        # ``w0`` is passed positionally below, so it must be popped: leaving it in ``kwargs``
        # would hand torchsde the same argument twice and raise TypeError.
        w0 = kwargs.pop("w0", None)
        if w0 is None:
            w0 = torch.zeros_like(x)
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        try:
            n_seeds = len(seed)
        except TypeError:
            # Scalar seed: a single tree covering the whole template tensor.
            seeds = [seed]
            self.batched = False
        else:
            # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
            if n_seeds != x.shape[0]:
                raise ValueError(
                    f"Expected one seed per batch item ({x.shape[0]}), but got {n_seeds} seeds."
                )
            seeds = seed
            self.batched = True
            # One tree per batch item, each starting from a single item's slice.
            w0 = w0[0]
        if self.cpu_tree:
            t0, t1, w0 = t0.cpu(), t1.cpu(), w0.cpu()
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seeds]

    @staticmethod
    def sort(a, b):
        """Return ``(low, high, sign)`` where ``sign`` is -1 if ``a`` and ``b`` were swapped."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Query every tree between ``t0`` and ``t1`` (given in either order).

        Each tree is queried on the sorted interval; the result is negated when the query order
        and the construction order disagree. Returns the stacked per-seed results when batched,
        otherwise the single tree's result.
        """
        t0, t1, sign = self.sort(t0, t1)
        if self.cpu_tree:
            # Query on the CPU in float32, then restore the caller's dtype and device.
            w = torch.stack(
                [tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]
            ) * (self.sign * sign)
        else:
            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
|
|
|
|
class BrownianTreeNoiseSampler:
    """Noise sampler backed by a ``BatchedBrownianTree`` (torchsde Brownian trees).

    Calling the sampler with ``(sigma, sigma_next)`` maps both values through ``transform``,
    queries the tree on that interval and divides the result by ``sqrt(|t1 - t0|)`` of the
    transformed endpoints.

    Args:
        x (Tensor): Template tensor handed to ``BatchedBrownianTree``; it fixes the shape of
            the generated noise.
        sigma_min (float): Low end of the interval the tree is built over.
        sigma_max (float): High end of that interval.
        seed (int or List[int]): The random seed. A list supplies one seed (and one tree) per
            batch item.
        transform (callable): Maps a sigma (as a tensor) to the tree's internal time.
            Defaults to the identity.
        cpu (bool): Whether the underlying trees are built and queried on the CPU.
            Defaults to False.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
        self.transform = transform
        lower = self.transform(torch.as_tensor(sigma_min))
        upper = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, lower, upper, seed, cpu=cpu)

    def __call__(self, sigma, sigma_next):
        start = self.transform(torch.as_tensor(sigma))
        end = self.transform(torch.as_tensor(sigma_next))
        increment = self.tree(start, end)
        return increment / (end - start).abs().sqrt()
|
|
|
|
| class OmniZeroPipeline( |
| DiffusionPipeline, |
| StableDiffusionMixin, |
| TextualInversionLoaderMixin, |
| StableDiffusionXLLoraLoaderMixin, |
| FromSingleFileMixin, |
| IPAdapterMixin, |
| ): |
| r""" |
| Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance. |
| |
| This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the |
| library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) |
| |
| The pipeline also inherits the following loading methods: |
| - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings |
| - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights |
| - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights |
| - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters |
| |
| Args: |
| vae ([`AutoencoderKL`]): |
| Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. |
| text_encoder ([`CLIPTextModel`]): |
| Frozen text-encoder. Stable Diffusion uses the text portion of |
| [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically |
| the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. |
| text_encoder_2 ([` CLIPTextModelWithProjection`]): |
| Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of |
| [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), |
| specifically the |
| [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) |
| variant. |
| tokenizer (`CLIPTokenizer`): |
| Tokenizer of class |
| [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). |
| tokenizer_2 (`CLIPTokenizer`): |
| Second Tokenizer of class |
| [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). |
| unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. |
| controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): |
| Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets |
| as a list, the outputs from each ControlNet are added together to create one combined additional |
| conditioning. |
| scheduler ([`SchedulerMixin`]): |
| A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of |
| [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. |
| requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): |
| Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the |
| config of `stabilityai/stable-diffusion-xl-refiner-1-0`. |
| force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): |
| Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of |
| `stabilityai/stable-diffusion-xl-base-1-0`. |
| add_watermarker (`bool`, *optional*): |
| Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to |
| watermark output images. If not defined, it will default to True if the package is installed, otherwise no |
| watermarker will be used. |
| feature_extractor ([`~transformers.CLIPImageProcessor`]): |
| A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. |
| """ |
|
|
| model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" |
| _optional_components = [ |
| "tokenizer", |
| "tokenizer_2", |
| "text_encoder", |
| "text_encoder_2", |
| "feature_extractor", |
| "image_encoder", |
| ] |
| _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] |
|
|
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    requires_aesthetics_score: bool = False,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
):
    """Register the pipeline components and derive the image processors and config from them.

    - A list/tuple of ControlNets is wrapped in a ``MultiControlNetModel``.
    - ``vae_scale_factor`` is ``2 ** (len(vae.config.block_out_channels) - 1)``.
    - ``image_processor`` and ``control_image_processor`` are built from it; the control one
      uses ``do_normalize=False``.
    - When ``add_watermarker`` is ``None`` it defaults to whether the invisible-watermark
      package is available.
    - ``force_zeros_for_empty_prompt`` and ``requires_aesthetics_score`` are recorded in the
      pipeline config.
    - ``ays_noise_sigmas`` holds three preset 11-entry sigma schedules keyed ``"SD1"``,
      ``"SDXL"`` and ``"SVD"`` (presumably "Align Your Steps" schedules -- confirm).
    """
    super().__init__()

    # Several ControlNets given as a sequence are combined into a single module.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )

    scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.vae_scale_factor = scale_factor
    self.image_processor = VaeImageProcessor(vae_scale_factor=scale_factor, do_convert_rgb=True)
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=scale_factor, do_convert_rgb=True, do_normalize=False
    )

    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    self.ays_noise_sigmas = {
        "SD1": [
            14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092,
            0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582,
        ],
        "SDXL": [
            14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141,
            0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582,
        ],
        "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002],
    }
| |
@staticmethod
def _loglinear_interp(t_steps, num_steps):
    """Resample a sigma schedule to ``num_steps`` points by linear interpolation in log space.

    ``t_steps`` is reversed before taking the log, interpolated on a uniform [0, 1] grid,
    exponentiated, and reversed back, so the output keeps the input's ordering (e.g. the
    decreasing lists in ``ays_noise_sigmas``). All entries must be positive for the log.

    Args:
        t_steps: Sequence of positive values supporting ``[::-1]`` slicing.
        num_steps: Number of points in the returned schedule.

    Returns:
        numpy.ndarray: A contiguous array of length ``num_steps``.
    """
    log_ascending = np.log(t_steps[::-1])
    source_grid = np.linspace(0, 1, len(t_steps))
    target_grid = np.linspace(0, 1, num_steps)
    resampled = np.exp(np.interp(target_grid, source_grid, log_ascending))
    return resampled[::-1].copy()
| |
| |
def encode_prompt(
    self,
    prompt: Optional[Union[str, List[str]]],
    prompt_2: Optional[Union[str, List[str]]] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device; defaults to `self._execution_device`.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings. Only applied to
            the positive prompt; the negative prompt always uses `hidden_states[-2]`.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`.
        Each prompt's embeddings are repeated `num_images_per_prompt` times along dim 0. The negative entries
        are returned as passed in (possibly `None`) when `do_classifier_free_guidance` is False.

    Raises:
        TypeError: If `negative_prompt` and `prompt` have different types.
        ValueError: If `negative_prompt` has a different batch size than `prompt`.
    """
    device = device or self._execution_device

    # Store the LoRA scale on the pipeline and apply it to the text encoders for this call only;
    # it is undone at the end of this method via `unscale_lora_layers`.
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # Without the PEFT backend the scale is patched in; with PEFT the LoRA layers are rescaled.
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    # Normalize a single string to a one-element batch.
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # With no first tokenizer/encoder, only the second pair is used (and zip() below then
    # consumes only `prompt`, not `prompt_2`).
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # Encode with each text encoder and concatenate the results along the feature axis.
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE: the loop variable rebinds `prompt`; the negative-prompt type check further down
        # therefore sees the value from the last iteration.
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            # Textual inversion: expand multi-vector tokens if necessary.
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn (don't fail) when truncation to model_max_length dropped tokens.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # Overwritten on every iteration, so only the last text encoder's first output
            # survives as the pooled embedding.
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # clip_skip counts layers skipped beyond the default penultimate layer (-2).
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # Unconditional embeddings for classifier-free guidance: either all zeros (no negative prompt
    # and force_zeros_for_empty_prompt set) or an encoding of the negative prompt(s).
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # Normalize str to a list matching the prompt batch size.
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # Pad/truncate to the positive embeddings' sequence length so the two can be batched together.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # As above, only the last text encoder's first output survives as the pooled embedding.
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            # NOTE(review): unlike the positive branch, clip_skip is not applied here.
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # Duplicate each prompt's embeddings num_images_per_prompt times: tile along the sequence
    # axis, then reshape so the copies of one prompt are consecutive rows.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # Same duplication for the unconditional embeddings.
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        # NOTE(review): uses bs_embed here but batch_size for negative_prompt_embeds above;
        # these agree whenever the negative batch matches the prompt batch.
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Restore the original LoRA scale by scaling the layers back.
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Restore the original LoRA scale by scaling the layers back.
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
|
| |
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None, unconditional_noising_factor=1.0):
    """Encode an IP-Adapter image and build a noise-based "unconditional" counterpart.

    The unconditional input is uniform noise (``torch.rand_like``) scaled by
    ``unconditional_noising_factor``. It is drawn from a seed derived from the image's contents,
    so the same image always yields the same unconditional input. The seeding happens inside
    ``torch.random.fork_rng`` so the caller's global RNG state is left untouched.

    Args:
        image: Input accepted by ``self.feature_extractor`` (then encoded with
            ``self.image_encoder``), or a ``torch.Tensor``. A tensor is not encoded: it is used
            as-is with two leading singleton dims added.
        device: Device the image (and hence the noise) is moved to.
        num_images_per_prompt: Number of copies made along dim 0 via ``repeat_interleave``.
        output_hidden_states: If truthy, return the encoder's ``hidden_states[-2]``; otherwise
            return its ``image_embeds``. Has no effect for tensor inputs.
        unconditional_noising_factor: Scale applied to the uniform noise forming the
            unconditional input.

    Returns:
        tuple: ``(embeds, uncond_embeds)``, both repeated ``num_images_per_prompt`` times along dim 0.
    """
    dtype = next(self.image_encoder.parameters()).dtype

    needs_encoding = not isinstance(image, torch.Tensor)
    if needs_encoding:
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)

    # Content-derived seed makes the unconditional noise reproducible for a given image.
    avg_image = torch.mean(image, dim=0, keepdim=True).to(dtype=torch.float32)
    seed = int(torch.sum(avg_image).item()) % 1000000007
    # Seed inside fork_rng so the global RNG state is restored afterwards. Otherwise every later
    # unseeded random op (e.g. latent noise) would become a function of this image.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        additional_noise_for_uncond = torch.rand_like(image) * unconditional_noising_factor

    def _embed(pixels):
        # Tensor inputs bypass the encoder entirely.
        if not needs_encoding:
            return pixels.unsqueeze(0).unsqueeze(0)
        if output_hidden_states:
            return self.image_encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return self.image_encoder(pixels).image_embeds

    image_embeds = _embed(image).repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = _embed(additional_noise_for_uncond).repeat_interleave(num_images_per_prompt, dim=0)
    return image_embeds, uncond_image_embeds
|
|
| |
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Produce one embedding tensor per loaded IP-Adapter.

    There are two input modes:

    * ``ip_adapter_image_embeds`` given: each tensor is tiled ``num_images_per_prompt`` times
      along dim 0.
      - With classifier-free guidance, each tensor is first split with ``chunk(2)`` into
        ``(negative, positive)``.
      - Both halves are tiled and re-joined as ``[negative, positive]``.
    * Otherwise, ``ip_adapter_image`` (one image, or a list with one entry per adapter) is
      encoded via ``self.encode_image``.
      - The positive and negative results are each stacked ``num_images_per_prompt`` times
        along a new leading dim.
      - With guidance they are concatenated as ``[negative, positive]`` and moved to ``device``.

    Returns:
        list of ``torch.Tensor``, one per IP-Adapter.

    Raises:
        ValueError: If the number of images differs from the number of IP-Adapter projection
            layers on ``self.unet.encoder_hid_proj``.
    """
    if ip_adapter_image_embeds is not None:

        def _tile(embeds):
            # Repeat only along dim 0; every trailing dim keeps a repeat factor of 1.
            return embeds.repeat(num_images_per_prompt, *([1] * len(embeds.shape[1:])))

        prepared = []
        for embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative, positive = embeds.chunk(2)
                embeds = torch.cat([_tile(negative), _tile(positive)])
            else:
                embeds = _tile(embeds)
            prepared.append(embeds)
        return prepared

    images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
    projection_layers = self.unet.encoder_hid_proj.image_projection_layers
    if len(images) != len(projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
        )

    prepared = []
    for single_image, projection_layer in zip(images, projection_layers):
        # ImageProjection layers get encode_image's pooled `image_embeds`;
        # any other projection type gets the penultimate hidden states.
        use_hidden_states = not isinstance(projection_layer, ImageProjection)
        positive, negative = self.encode_image(single_image, device, 1, use_hidden_states)
        positive = torch.stack([positive] * num_images_per_prompt, dim=0)
        negative = torch.stack([negative] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            positive = torch.cat([negative, positive]).to(device)
        prepared.append(positive)
    return prepared
|
|
| |
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments for ``self.scheduler.step``.

    Only the arguments that the scheduler's ``step`` signature declares are included, so
    schedulers that do not accept ``eta`` or ``generator`` are not passed them.

    Args:
        generator: Included as ``"generator"`` when ``step`` accepts it.
        eta: Included as ``"eta"`` when ``step`` accepts it.

    Returns:
        dict: The subset of ``{"eta": eta, "generator": generator}`` that ``step`` accepts.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
|
|
| def check_inputs( |
| self, |
| prompt, |
| prompt_2, |
| image, |
| strength, |
| num_inference_steps, |
| callback_steps, |
| negative_prompt=None, |
| negative_prompt_2=None, |
| prompt_embeds=None, |
| negative_prompt_embeds=None, |
| pooled_prompt_embeds=None, |
| negative_pooled_prompt_embeds=None, |
| ip_adapter_image=None, |
| ip_adapter_image_embeds=None, |
| controlnet_conditioning_scale=1.0, |
| control_guidance_start=0.0, |
| control_guidance_end=1.0, |
| callback_on_step_end_tensor_inputs=None, |
| ): |
| if strength < 0 or strength > 1: |
| raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") |
| if num_inference_steps is None: |
| raise ValueError("`num_inference_steps` cannot be None.") |
| elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: |
| raise ValueError( |
| f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" |
| f" {type(num_inference_steps)}." |
| ) |
|
|
| if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): |
| raise ValueError( |
| f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
| f" {type(callback_steps)}." |
| ) |
|
|
| if callback_on_step_end_tensor_inputs is not None and not all( |
| k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs |
| ): |
| raise ValueError( |
| f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" |
| ) |
|
|
| if prompt is not None and prompt_embeds is not None: |
| raise ValueError( |
| f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" |
| " only forward one of the two." |
| ) |
| elif prompt_2 is not None and prompt_embeds is not None: |
| raise ValueError( |
| f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" |
| " only forward one of the two." |
| ) |
| elif prompt is None and prompt_embeds is None: |
| raise ValueError( |
| "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." |
| ) |
| elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): |
| raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") |
| elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): |
| raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") |
|
|
| if negative_prompt is not None and negative_prompt_embeds is not None: |
| raise ValueError( |
| f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" |
| f" {negative_prompt_embeds}. Please make sure to only forward one of the two." |
| ) |
| elif negative_prompt_2 is not None and negative_prompt_embeds is not None: |
| raise ValueError( |
| f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" |
| f" {negative_prompt_embeds}. Please make sure to only forward one of the two." |
| ) |
|
|
| if prompt_embeds is not None and negative_prompt_embeds is not None: |
| if prompt_embeds.shape != negative_prompt_embeds.shape: |
| raise ValueError( |
| "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" |
| f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" |
| f" {negative_prompt_embeds.shape}." |
| ) |
|
|
| if prompt_embeds is not None and pooled_prompt_embeds is None: |
| raise ValueError( |
| "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." |
| ) |
|
|
| if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: |
| raise ValueError( |
| "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." |
| ) |
|
|
| |
| |
| if isinstance(self.controlnet, MultiControlNetModel): |
| if isinstance(prompt, list): |
| logger.warning( |
| f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" |
| " prompts. The conditionings will be fixed across the prompts." |
| ) |
|
|
| |
| is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( |
| self.controlnet, torch._dynamo.eval_frame.OptimizedModule |
| ) |
| if ( |
| isinstance(self.controlnet, ControlNetModel) |
| or is_compiled |
| and isinstance(self.controlnet._orig_mod, ControlNetModel) |
| ): |
| self.check_image(image, prompt, prompt_embeds) |
| elif ( |
| isinstance(self.controlnet, MultiControlNetModel) |
| or is_compiled |
| and isinstance(self.controlnet._orig_mod, MultiControlNetModel) |
| ): |
| if not isinstance(image, list): |
| raise TypeError("For multiple controlnets: `image` must be type `list`") |
|
|
| |
| |
| elif any(isinstance(i, list) for i in image): |
| raise ValueError("A single batch of multiple conditionings are supported at the moment.") |
| elif len(image) != len(self.controlnet.nets): |
| raise ValueError( |
| f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." |
| ) |
|
|
| for image_ in image: |
| self.check_image(image_, prompt, prompt_embeds) |
| else: |
| assert False |
|
|
| |
| if ( |
| isinstance(self.controlnet, ControlNetModel) |
| or is_compiled |
| and isinstance(self.controlnet._orig_mod, ControlNetModel) |
| ): |
| if not isinstance(controlnet_conditioning_scale, float): |
| raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") |
| elif ( |
| isinstance(self.controlnet, MultiControlNetModel) |
| or is_compiled |
| and isinstance(self.controlnet._orig_mod, MultiControlNetModel) |
| ): |
| if isinstance(controlnet_conditioning_scale, list): |
| if any(isinstance(i, list) for i in controlnet_conditioning_scale): |
| raise ValueError("A single batch of multiple conditionings are supported at the moment.") |
| elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( |
| self.controlnet.nets |
| ): |
| raise ValueError( |
| "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" |
| " the same length as the number of controlnets" |
| ) |
| else: |
| assert False |
|
|
| if not isinstance(control_guidance_start, (tuple, list)): |
| control_guidance_start = [control_guidance_start] |
|
|
| if not isinstance(control_guidance_end, (tuple, list)): |
| control_guidance_end = [control_guidance_end] |
|
|
| if len(control_guidance_start) != len(control_guidance_end): |
| raise ValueError( |
| f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." |
| ) |
|
|
| if isinstance(self.controlnet, MultiControlNetModel): |
| if len(control_guidance_start) != len(self.controlnet.nets): |
| raise ValueError( |
| f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." |
| ) |
|
|
| for start, end in zip(control_guidance_start, control_guidance_end): |
| if start >= end: |
| raise ValueError( |
| f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." |
| ) |
| if start < 0.0: |
| raise ValueError(f"control guidance start: {start} can't be smaller than 0.") |
| if end > 1.0: |
| raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") |
|
|
| if ip_adapter_image is not None and ip_adapter_image_embeds is not None: |
| raise ValueError( |
| "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." |
| ) |
|
|
| if ip_adapter_image_embeds is not None: |
| if not isinstance(ip_adapter_image_embeds, list): |
| raise ValueError( |
| f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" |
| ) |
| elif ip_adapter_image_embeds[0].ndim not in [3, 4]: |
| raise ValueError( |
| f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" |
| ) |
|
|
| |
def check_image(self, image, prompt, prompt_embeds):
    """Validate one ControlNet conditioning image against the prompt batch.

    Args:
        image: a PIL image, numpy array, torch tensor, or a non-empty list whose first
            element is one of those.
        prompt: ``str`` or ``list`` of prompts, or ``None`` when ``prompt_embeds`` is used.
        prompt_embeds: pre-computed prompt embeddings; its first dimension is used as the
            prompt batch size when ``prompt`` is not a ``str``/``list``.

    Raises:
        TypeError: if ``image`` is none of the accepted types (an empty list included).
        ValueError: if the image batch size is neither 1 nor equal to the prompt batch size.
    """
    # Probe the first element only for a *non-empty* list: indexing an empty list used to
    # raise IndexError instead of the descriptive TypeError below.
    first_item = image[0] if isinstance(image, list) and len(image) > 0 else None

    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    image_is_np = isinstance(image, np.ndarray)
    image_is_pil_list = isinstance(first_item, PIL.Image.Image)
    image_is_tensor_list = isinstance(first_item, torch.Tensor)
    image_is_np_list = isinstance(first_item, np.ndarray)

    if not any(
        (
            image_is_pil,
            image_is_tensor,
            image_is_np,
            image_is_pil_list,
            image_is_tensor_list,
            image_is_np_list,
        )
    ):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    # A single PIL image counts as one sample; for every other accepted type `len` gives the
    # batch (outer dimension of an array/tensor, or the list length).
    image_batch_size = 1 if image_is_pil else len(image)

    # NOTE(review): if neither a str/list prompt nor prompt_embeds is given, prompt_batch_size
    # stays unbound; check_inputs rejects those cases before calling this method.
    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
|
|
| |
def prepare_control_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess and batch the ControlNet conditioning image.

    The image goes through ``self.control_image_processor.preprocess`` at the requested
    ``height``/``width`` and is converted to float32.
    - A single image (batch of 1) is repeated ``batch_size`` times.
    - Otherwise every entry is repeated ``num_images_per_prompt`` times (``repeat_interleave``
      along dim 0).

    The result is moved to ``device``/``dtype``. When classifier-free guidance is enabled and
    ``guess_mode`` is off, the tensor is concatenated with itself along dim 0.

    Returns:
        torch.Tensor: the batched control image.
    """
    processed = self.control_image_processor.preprocess(image, height=height, width=width)
    processed = processed.to(dtype=torch.float32)

    # A lone image is broadcast over the whole batch; a batch of images (already checked to
    # match the prompt batch by check_image) is only expanded per prompt.
    repeats = batch_size if processed.shape[0] == 1 else num_images_per_prompt
    processed = processed.repeat_interleave(repeats, dim=0)
    processed = processed.to(device=device, dtype=dtype)

    if not do_classifier_free_guidance or guess_mode:
        return processed
    return torch.cat([processed, processed])
|
|
| |
def get_timesteps(self, num_inference_steps, strength, device):
    """Pick the tail of the scheduler's timesteps that a strength-limited run should use.

    ``int(num_inference_steps * strength)`` steps (capped at ``num_inference_steps``) are
    kept; the leading ones are skipped. The slice offset is scaled by
    ``self.scheduler.order``. Schedulers exposing ``set_begin_index`` are told the offset.
    ``device`` is accepted for interface compatibility and is not used.

    Returns:
        tuple: ``(timesteps_to_run, number_of_steps_to_run)``.
    """
    scheduler = self.scheduler
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    steps_skipped = max(num_inference_steps - steps_to_run, 0)
    begin_index = steps_skipped * scheduler.order

    remaining = scheduler.timesteps[begin_index:]
    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(begin_index)

    return remaining, num_inference_steps - steps_skipped
|
|
| |
def prepare_latents(
    self,
    image,
    timestep,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator=None,
    add_noise=True,
    seed=None,
):
    """Build the initial latents for denoising, optionally noised to ``timestep``.

    Args:
        image: ``None`` starts from zero latents of shape
            ``(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)``.
            Otherwise a tensor that is used as-is when it has 4 channels (already latents), or
            encoded with ``self.vae`` otherwise. PIL images / lists pass the type check but would
            fail at ``.to(...)`` — presumably callers preprocess first (TODO confirm).
        timestep: forwarded to ``self.scheduler.add_noise``.
        batch_size: target batch; latents whose batch divides it are tiled up to it.
        num_channels_latents, height, width: only used when ``image`` is ``None``.
        dtype, device: dtype/device of the returned latents.
        generator: a ``torch.Generator`` or a list of them (one per sample); used for the
            noise draw and passed to ``retrieve_latents`` when encoding.
        add_noise: whether to noise the latents via ``self.scheduler.add_noise``.
        seed: if given, the noise generator becomes ``torch.manual_seed(seed)``.

    Returns:
        torch.Tensor: the (possibly noised) initial latents.

    Raises:
        ValueError: for an unsupported ``image`` type, a generator list whose length does
            not match the batch, or a batch size that is not a multiple of the latent batch.
    """
    if image is None:
        # Start from zero latents; the noise (if requested) is added further below.
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        init_latents = torch.zeros(shape, device=device, dtype=dtype)
    else:
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        # Optional per-channel latent statistics from the VAE config.
        latents_mean = latents_std = None
        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

        # When an offload hook is installed, move text_encoder_2 off the device and release
        # cached CUDA memory before encoding.
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.text_encoder_2.to("cpu")
            torch.cuda.empty_cache()

        image = image.to(device=device, dtype=dtype)

        if image.shape[1] == 4:
            # Four channels: treat the input as latents and skip the VAE.
            init_latents = image
        else:
            # Encode in float32 when the VAE config requests it (`force_upcast`).
            if self.vae.config.force_upcast:
                image = image.float()
                self.vae.to(dtype=torch.float32)

            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
            elif isinstance(generator, list):
                init_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

            if self.vae.config.force_upcast:
                self.vae.to(dtype)

            init_latents = init_latents.to(dtype)
            if latents_mean is not None and latents_std is not None:
                latents_mean = latents_mean.to(device=self.device, dtype=dtype)
                latents_std = latents_std.to(device=self.device, dtype=dtype)
                init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
            else:
                init_latents = self.vae.config.scaling_factor * init_latents

    # Tile the latents up to `batch_size` (a no-op for the zero-latent path).
    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        init_latents = torch.cat([init_latents], dim=0)

    if add_noise:
        if seed is not None:
            # NOTE: torch.manual_seed seeds the *global* RNG and returns the default CPU
            # generator, so later unseeded draws are affected too; kept for reproducibility.
            generator = torch.manual_seed(seed)

        # Noise is always drawn in float32 on the CPU and then moved to `device`.
        noise_shape = torch.Size(init_latents.shape)
        if isinstance(generator, list):
            # torch.randn cannot take a list of generators: draw one sample per generator so
            # each controls its own latent.
            if len(generator) != noise_shape[0]:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but the latents have a batch"
                    f" size of {noise_shape[0]}. Make sure the number of generators matches the batch size."
                )
            noise = torch.cat(
                [
                    torch.randn(
                        (1, *noise_shape[1:]),
                        dtype=torch.float32,
                        layout=torch.strided,
                        generator=g,
                        device="cpu",
                    )
                    for g in generator
                ],
                dim=0,
            )
        else:
            noise = torch.randn(
                noise_shape, dtype=torch.float32, layout=torch.strided, generator=generator, device="cpu"
            )
        noise = noise.to(device)
        init_latents = self.scheduler.add_noise(init_latents.to(device), noise, timestep)
        return init_latents.to(device, dtype=dtype)

    return init_latents
|
|
| |
| def _get_add_time_ids( |
| self, |
| original_size, |
| crops_coords_top_left, |
| target_size, |
| aesthetic_score, |
| negative_aesthetic_score, |
| negative_original_size, |
| negative_crops_coords_top_left, |
| negative_target_size, |
| dtype, |
| text_encoder_projection_dim=None, |
| ): |
| if self.config.requires_aesthetics_score: |
| add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) |
| add_neg_time_ids = list( |
| negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) |
| ) |
| else: |
| add_time_ids = list(original_size + crops_coords_top_left + target_size) |
| add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) |
|
|
| passed_add_embed_dim = ( |
| self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim |
| ) |
| expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features |
|
|
| if ( |
| expected_add_embed_dim > passed_add_embed_dim |
| and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim |
| ): |
| raise ValueError( |
| f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." |
| ) |
| elif ( |
| expected_add_embed_dim < passed_add_embed_dim |
| and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim |
| ): |
| raise ValueError( |
| f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." |
| ) |
| elif expected_add_embed_dim != passed_add_embed_dim: |
| raise ValueError( |
| f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." |
| ) |
|
|
| add_time_ids = torch.tensor([add_time_ids], dtype=dtype) |
| add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) |
|
|
| return add_time_ids, add_neg_time_ids |
|
|
| |
def upcast_vae(self):
    """Cast the VAE to float32, optionally restoring some decoder parts to the old dtype.

    The VAE's dtype is remembered, then the whole VAE is moved to float32. If the attention
    processor of ``decoder.mid_block.attentions[0]`` is one of ``AttnProcessor2_0``,
    ``XFormersAttnProcessor``, ``LoRAXFormersAttnProcessor`` or ``LoRAAttnProcessor2_0``, then
    ``post_quant_conv``, ``decoder.conv_in`` and ``decoder.mid_block`` are cast back to the
    remembered dtype.
    """
    original_dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)

    mid_attention_processor = self.vae.decoder.mid_block.attentions[0].processor
    efficient_processors = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        LoRAXFormersAttnProcessor,
        LoRAAttnProcessor2_0,
    )
    if not isinstance(mid_attention_processor, efficient_processors):
        return

    # NOTE(review): presumably these processors keep the attention numerically safe in the
    # original dtype, so these modules need not stay in float32 — confirm against upstream.
    for module in (self.vae.post_quant_conv, self.vae.decoder.conv_in, self.vae.decoder.mid_block):
        module.to(original_dtype)
|
|
@property
def guidance_scale(self):
    """Classifier-free guidance weight, as stored in ``self._guidance_scale`` by ``__call__``."""
    scale = self._guidance_scale
    return scale
|
|
@property
def clip_skip(self):
    """Number of CLIP layers to skip, as stored in ``self._clip_skip`` by ``__call__``."""
    skip = self._clip_skip
    return skip
|
|
| |
| |
| |
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active: true when the stored guidance scale exceeds 1."""
    scale = self._guidance_scale
    guidance_enabled = scale > 1
    return guidance_enabled
|
|
@property
def cross_attention_kwargs(self):
    """Extra attention-processor kwargs, as stored in ``self._cross_attention_kwargs`` by ``__call__``."""
    attention_kwargs = self._cross_attention_kwargs
    return attention_kwargs
|
|
@property
def num_timesteps(self):
    """Number of denoising timesteps recorded in ``self._num_timesteps`` (presumably set during ``__call__``)."""
    count = self._num_timesteps
    return count
|
|
| @torch.no_grad() |
| @replace_example_docstring(EXAMPLE_DOC_STRING) |
| def __call__( |
| self, |
| prompt: Union[str, List[str]] = None, |
| prompt_2: Optional[Union[str, List[str]]] = None, |
| image: PipelineImageInput = None, |
| control_image: PipelineImageInput = None, |
| control_mask = None, |
| identity_control_indices = None, |
| height: Optional[int] = None, |
| width: Optional[int] = None, |
| strength: float = 0.8, |
| num_inference_steps: int = 50, |
| timesteps: Optional[List[int]] = None, |
| sigmas: Optional[List[float]] = None, |
| guidance_scale: float = 5.0, |
| negative_prompt: Optional[Union[str, List[str]]] = None, |
| negative_prompt_2: Optional[Union[str, List[str]]] = None, |
| num_images_per_prompt: Optional[int] = 1, |
| eta: float = 0.0, |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| seed: Optional[int] = None, |
| latents: Optional[torch.FloatTensor] = None, |
| prompt_embeds: Optional[torch.FloatTensor] = None, |
| negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
| pooled_prompt_embeds: Optional[torch.FloatTensor] = None, |
| negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, |
| ip_adapter_image: Optional[PipelineImageInput] = None, |
| ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, |
| output_type: Optional[str] = "pil", |
| return_dict: bool = True, |
| cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
| controlnet_conditioning_scale: Union[float, List[float]] = 0.8, |
| guess_mode: bool = False, |
| control_guidance_start: Union[float, List[float]] = 0.0, |
| control_guidance_end: Union[float, List[float]] = 1.0, |
| original_size: Tuple[int, int] = None, |
| crops_coords_top_left: Tuple[int, int] = (0, 0), |
| target_size: Tuple[int, int] = None, |
| negative_original_size: Optional[Tuple[int, int]] = None, |
| negative_crops_coords_top_left: Tuple[int, int] = (0, 0), |
| negative_target_size: Optional[Tuple[int, int]] = None, |
| aesthetic_score: float = 6.0, |
| negative_aesthetic_score: float = 2.5, |
| clip_skip: Optional[int] = None, |
| callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
| callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
| **kwargs, |
| ): |
| r""" |
| Function invoked when calling the pipeline for generation. |
| |
| Args: |
| prompt (`str` or `List[str]`, *optional*): |
| The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. |
| instead. |
| prompt_2 (`str` or `List[str]`, *optional*): |
| The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is |
| used in both text-encoders |
| image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: |
| `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): |
| The initial image will be used as the starting point for the image generation process. Can also accept |
| image latents as `image`, if passing latents directly, it will not be encoded again. |
| control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: |
| `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): |
| The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If |
| the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can |
| also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If |
| height and/or width are passed, `image` is resized according to them. If multiple ControlNets are |
| specified in init, images must be passed as a list such that each element of the list can be correctly |
| batched for input to a single controlnet. |
| height (`int`, *optional*, defaults to the size of control_image): |
| The height in pixels of the generated image. Anything below 512 pixels won't work well for |
| [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) |
| and checkpoints that are not specifically fine-tuned on low resolutions. |
| width (`int`, *optional*, defaults to the size of control_image): |
| The width in pixels of the generated image. Anything below 512 pixels won't work well for |
| [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) |
| and checkpoints that are not specifically fine-tuned on low resolutions. |
| strength (`float`, *optional*, defaults to 0.8): |
| Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a |
| starting point and more noise is added the higher the `strength`. The number of denoising steps depends |
| on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising |
| process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 |
| essentially ignores `image`. |
| num_inference_steps (`int`, *optional*, defaults to 50): |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| expense of slower inference. |
| guidance_scale (`float`, *optional*, defaults to 7.5): |
| Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
| `guidance_scale` is defined as `w` of equation 2. of [Imagen |
| Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
| 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
| usually at the expense of lower image quality. |
| negative_prompt (`str` or `List[str]`, *optional*): |
| The prompt or prompts not to guide the image generation. If not defined, one has to pass |
| `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is |
| less than `1`). |
| negative_prompt_2 (`str` or `List[str]`, *optional*): |
| The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and |
| `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders |
| num_images_per_prompt (`int`, *optional*, defaults to 1): |
| The number of images to generate per prompt. |
| eta (`float`, *optional*, defaults to 0.0): |
| Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to |
| [`schedulers.DDIMScheduler`], will be ignored for others. |
| generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
| One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) |
| to make generation deterministic. |
| latents (`torch.FloatTensor`, *optional*): |
| Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image |
| generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
| tensor will ge generated by sampling using the supplied random `generator`. |
| prompt_embeds (`torch.FloatTensor`, *optional*): |
| Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not |
| provided, text embeddings will be generated from `prompt` input argument. |
| negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
| Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
| weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input |
| argument. |
| pooled_prompt_embeds (`torch.FloatTensor`, *optional*): |
| Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. |
| If not provided, pooled text embeddings will be generated from `prompt` input argument. |
| negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): |
| Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
| weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` |
| input argument. |
| ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. |
| ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): |
| Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of |
| IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should |
| contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not |
| provided, embeddings are computed from the `ip_adapter_image` input argument. |
| output_type (`str`, *optional*, defaults to `"pil"`): |
| The output format of the generate image. Choose between |
| [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. |
| return_dict (`bool`, *optional*, defaults to `True`): |
| Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
| plain tuple. |
| cross_attention_kwargs (`dict`, *optional*): |
| A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under |
| `self.processor` in |
| [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
| controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): |
| The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added |
| to the residual in the original unet. If multiple ControlNets are specified in init, you can set the |
| corresponding scale as a list. |
| guess_mode (`bool`, *optional*, defaults to `False`): |
| In this mode, the ControlNet encoder will try best to recognize the content of the input image even if |
| you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. |
| control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): |
| The percentage of total steps at which the controlnet starts applying. |
| control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): |
| The percentage of total steps at which the controlnet stops applying. |
| original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
| If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. |
| `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as |
| explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
| crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): |
| `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position |
| `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting |
| `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
| target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
| For most cases, `target_size` should be set to the desired height and width of the generated image. If |
| not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in |
| section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
| negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
| To negatively condition the generation process based on a specific image resolution. Part of SDXL's |
| micro-conditioning as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more |
| information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. |
| negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): |
| To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's |
| micro-conditioning as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more |
| information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. |
| negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): |
| To negatively condition the generation process based on a target image resolution. It should be as same |
| as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more |
| information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. |
| aesthetic_score (`float`, *optional*, defaults to 6.0): |
| Used to simulate an aesthetic score of the generated image by influencing the positive text condition. |
| Part of SDXL's micro-conditioning as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
| negative_aesthetic_score (`float`, *optional*, defaults to 2.5): |
| Part of SDXL's micro-conditioning as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to |
| simulate an aesthetic score of the generated image by influencing the negative text condition. |
| clip_skip (`int`, *optional*): |
| Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that |
| the output of the pre-final layer will be used for computing the prompt embeddings. |
| callback_on_step_end (`Callable`, *optional*): |
| A function that calls at the end of each denoising steps during the inference. The function is called |
| with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, |
| callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by |
| `callback_on_step_end_tensor_inputs`. |
| callback_on_step_end_tensor_inputs (`List`, *optional*): |
| The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list |
| will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the |
| `._callback_tensor_inputs` attribute of your pipeline class. |
| |
| Examples: |
| |
| Returns: |
| [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
| [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` |
| containing the output images. |
| """ |
|
|
| callback = kwargs.pop("callback", None) |
| callback_steps = kwargs.pop("callback_steps", None) |
|
|
| if callback is not None: |
| deprecate( |
| "callback", |
| "1.0.0", |
| "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", |
| ) |
| if callback_steps is not None: |
| deprecate( |
| "callback_steps", |
| "1.0.0", |
| "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", |
| ) |
|
|
| controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet |
|
|
| |
| if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): |
| control_guidance_start = len(control_guidance_end) * [control_guidance_start] |
| elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): |
| control_guidance_end = len(control_guidance_start) * [control_guidance_end] |
| elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): |
| mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 |
| control_guidance_start, control_guidance_end = ( |
| mult * [control_guidance_start], |
| mult * [control_guidance_end], |
| ) |
|
|
| |
| self.check_inputs( |
| prompt, |
| prompt_2, |
| control_image, |
| strength, |
| num_inference_steps, |
| callback_steps, |
| negative_prompt, |
| negative_prompt_2, |
| prompt_embeds, |
| negative_prompt_embeds, |
| pooled_prompt_embeds, |
| negative_pooled_prompt_embeds, |
| ip_adapter_image, |
| ip_adapter_image_embeds, |
| controlnet_conditioning_scale, |
| control_guidance_start, |
| control_guidance_end, |
| callback_on_step_end_tensor_inputs, |
| ) |
|
|
| self._guidance_scale = guidance_scale |
| self._clip_skip = clip_skip |
| self._cross_attention_kwargs = cross_attention_kwargs |
|
|
| |
| if prompt is not None and isinstance(prompt, str): |
| batch_size = 1 |
| elif prompt is not None and isinstance(prompt, list): |
| batch_size = len(prompt) |
| else: |
| batch_size = prompt_embeds.shape[0] |
|
|
| device = self._execution_device |
|
|
| if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): |
| controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) |
|
|
| global_pool_conditions = ( |
| controlnet.config.global_pool_conditions |
| if isinstance(controlnet, ControlNetModel) |
| else controlnet.nets[0].config.global_pool_conditions |
| ) |
| guess_mode = guess_mode or global_pool_conditions |
|
|
| |
| text_encoder_lora_scale = ( |
| self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None |
| ) |
| ( |
| prompt_embeds, |
| negative_prompt_embeds, |
| pooled_prompt_embeds, |
| negative_pooled_prompt_embeds, |
| ) = self.encode_prompt( |
| prompt, |
| prompt_2, |
| device, |
| num_images_per_prompt, |
| self.do_classifier_free_guidance, |
| negative_prompt, |
| negative_prompt_2, |
| prompt_embeds=prompt_embeds, |
| negative_prompt_embeds=negative_prompt_embeds, |
| pooled_prompt_embeds=pooled_prompt_embeds, |
| negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, |
| lora_scale=text_encoder_lora_scale, |
| clip_skip=self.clip_skip, |
| ) |
|
|
| |
| if ip_adapter_image is not None or ip_adapter_image_embeds is not None: |
| image_embeds = self.prepare_ip_adapter_image_embeds( |
| ip_adapter_image, |
| ip_adapter_image_embeds, |
| device, |
| batch_size * num_images_per_prompt, |
| self.do_classifier_free_guidance, |
| ) |
|
|
| |
| if image is not None: |
| image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) |
| else: |
| strength = 1.0 |
|
|
| if isinstance(controlnet, ControlNetModel): |
| control_image = self.prepare_control_image( |
| image=control_image, |
| width=width, |
| height=height, |
| batch_size=batch_size * num_images_per_prompt, |
| num_images_per_prompt=num_images_per_prompt, |
| device=device, |
| dtype=controlnet.dtype, |
| do_classifier_free_guidance=self.do_classifier_free_guidance, |
| guess_mode=guess_mode, |
| ) |
| height, width = control_image.shape[-2:] |
| elif isinstance(controlnet, MultiControlNetModel): |
| control_images = [] |
|
|
| for control_image_ in control_image: |
| control_image_ = self.prepare_control_image( |
| image=control_image_, |
| width=width, |
| height=height, |
| batch_size=batch_size * num_images_per_prompt, |
| num_images_per_prompt=num_images_per_prompt, |
| device=device, |
| dtype=controlnet.dtype, |
| do_classifier_free_guidance=self.do_classifier_free_guidance, |
| guess_mode=guess_mode, |
| ) |
|
|
| control_images.append(control_image_) |
|
|
| control_image = control_images |
| height, width = control_image[0].shape[-2:] |
| else: |
| assert False |
|
|
| |
| controlnet_masks = [] |
| if control_mask is not None: |
| for mask in control_mask: |
| mask = np.array(mask) |
| mask_tensor = torch.from_numpy(mask).to(device=device, dtype=prompt_embeds.dtype) |
| mask_tensor = mask_tensor[:, :, 0] / 255. |
| mask_tensor = mask_tensor[None, None] |
| h, w = mask_tensor.shape[-2:] |
| control_mask_list = [] |
| for scale in [8, 8, 8, 16, 16, 16, 32, 32, 32]: |
| |
| w_n = round((w + 0.01) / 8) |
| h_n = round((h + 0.01) / 8) |
| if scale in [16, 32]: |
| w_n = round((w_n + 0.01) / 2) |
| h_n = round((h_n + 0.01) / 2) |
| if scale == 32: |
| w_n = round((w_n + 0.01) / 2) |
| h_n = round((h_n + 0.01) / 2) |
| scale_mask_weight_image_tensor = F.interpolate( |
| mask_tensor,(h_n, w_n), mode='bilinear') |
| control_mask_list.append(scale_mask_weight_image_tensor) |
| controlnet_masks.append(control_mask_list) |
|
|
| |
| full_num_inference_steps = int(num_inference_steps / strength) if strength > 0 else num_inference_steps |
|
|
| if timesteps is None: |
| self.scheduler.set_timesteps(full_num_inference_steps + 1, device=device) |
| sigmas = self._loglinear_interp(self.ays_noise_sigmas["SDXL"], full_num_inference_steps + 1) |
| sigmas[-1] = 0 |
| log_sigmas = np.log(np.array((1 - self.scheduler.alphas_cumprod) / self.scheduler.alphas_cumprod) ** 0.5) |
| timesteps = np.array([self.scheduler._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() |
| timesteps = timesteps[-(num_inference_steps + 1):-1] |
| if hasattr(self.scheduler, "sigmas"): |
| self.scheduler.sigmas = torch.from_numpy(sigmas)[-(num_inference_steps + 1):] |
| self.scheduler.timesteps = torch.from_numpy(timesteps).to(self.device, dtype=torch.int64) |
| self.scheduler.num_inference_steps = len(self.scheduler.timesteps) |
|
|
| else: |
| if "timesteps" in inspect.signature(self.scheduler.set_timesteps).parameters: |
| self.scheduler.set_timesteps(full_num_inference_steps + 1, timesteps=timesteps, device=device) |
| else: |
| self.scheduler.set_timesteps(full_num_inference_steps + 1, device=device) |
|
|
| latent_timestep = self.scheduler.timesteps[:1].repeat(batch_size * num_images_per_prompt) |
| self._num_timesteps = len(self.scheduler.timesteps) |
|
|
| |
| if latents is None: |
| num_channels_latents = self.unet.config.in_channels |
| latents = self.prepare_latents( |
| image, |
| latent_timestep, |
| batch_size * num_images_per_prompt, |
| num_channels_latents, |
| height, |
| width, |
| prompt_embeds.dtype, |
| device, |
| generator, |
| True, |
| seed |
| ) |
|
|
| if hasattr(self.scheduler, "sigmas"): |
| sigmas = self.scheduler.sigmas |
| sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() |
| seeds = [seed] * len(latents) if seed is not None else generator.seed() |
| brownian_tree_noise_sampler = BrownianTreeNoiseSampler(latents, sigma_min, sigma_max, seed=seeds, cpu=False) |
| else: |
| brownian_tree_noise_sampler = None |
|
|
| |
| extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
| |
| controlnet_keep = [] |
| for i in range(len(timesteps)): |
| keeps = [ |
| 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) |
| for s, e in zip(control_guidance_start, control_guidance_end) |
| ] |
| controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) |
|
|
| |
| if isinstance(control_image, list): |
| original_size = original_size or control_image[0].shape[-2:] |
| else: |
| original_size = original_size or control_image.shape[-2:] |
| target_size = target_size or (height, width) |
|
|
| if negative_original_size is None: |
| negative_original_size = original_size |
| if negative_target_size is None: |
| negative_target_size = target_size |
| add_text_embeds = pooled_prompt_embeds |
|
|
| if self.text_encoder_2 is None: |
| text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) |
| else: |
| text_encoder_projection_dim = self.text_encoder_2.config.projection_dim |
|
|
| add_time_ids, add_neg_time_ids = self._get_add_time_ids( |
| original_size, |
| crops_coords_top_left, |
| target_size, |
| aesthetic_score, |
| negative_aesthetic_score, |
| negative_original_size, |
| negative_crops_coords_top_left, |
| negative_target_size, |
| dtype=prompt_embeds.dtype, |
| text_encoder_projection_dim=text_encoder_projection_dim, |
| ) |
| add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) |
|
|
| if self.do_classifier_free_guidance: |
| prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) |
| add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) |
| add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) |
| add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) |
|
|
| prompt_embeds = prompt_embeds.to(device) |
| add_text_embeds = add_text_embeds.to(device) |
| add_time_ids = add_time_ids.to(device) |
|
|
| |
| num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
| with self.progress_bar(total=num_inference_steps) as progress_bar: |
| for i, t in enumerate(timesteps): |
| |
| latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents |
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
| added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} |
|
|
| |
| if guess_mode and self.do_classifier_free_guidance: |
| |
| control_model_input = latents |
| control_model_input = self.scheduler.scale_model_input(control_model_input, t) |
| controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] |
| controlnet_added_cond_kwargs = { |
| "text_embeds": add_text_embeds.chunk(2)[1], |
| "time_ids": add_time_ids.chunk(2)[1], |
| } |
| else: |
| control_model_input = latent_model_input |
| controlnet_prompt_embeds = prompt_embeds |
| controlnet_added_cond_kwargs = added_cond_kwargs |
|
|
| if isinstance(controlnet_keep[i], list): |
| cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] |
| else: |
| controlnet_cond_scale = controlnet_conditioning_scale |
| if isinstance(controlnet_cond_scale, list): |
| controlnet_cond_scale = controlnet_cond_scale[0] |
| cond_scale = controlnet_cond_scale * controlnet_keep[i] |
| |
| if ip_adapter_image_embeds is None and ip_adapter_image is not None: |
| encoder_hidden_states = self.unet.process_encoder_hidden_states(prompt_embeds, {"image_embeds": image_embeds}) |
| ip_adapter_image_embeds = encoder_hidden_states[1] |
|
|
| down_block_res_samples = None |
| mid_block_res_sample = None |
|
|
| for controlnet_index in range(len(self.controlnet.nets)): |
| ip_adapter_index = next((y for x, y in identity_control_indices if x == controlnet_index), None) |
| if ip_adapter_index is not None: |
| control_prompt_embeds = ip_adapter_image_embeds[ip_adapter_index].squeeze(1) |
| else: |
| control_prompt_embeds = controlnet_prompt_embeds |
| down_samples, mid_sample = self.controlnet.nets[controlnet_index]( |
| control_model_input, |
| t, |
| encoder_hidden_states=control_prompt_embeds, |
| controlnet_cond=control_image[controlnet_index], |
| conditioning_scale=cond_scale[controlnet_index], |
| guess_mode=guess_mode, |
| added_cond_kwargs=controlnet_added_cond_kwargs, |
| return_dict=False, |
| ) |
|
|
| if len(controlnet_masks) > controlnet_index and controlnet_masks[controlnet_index] is not None: |
| down_samples = [ |
| down_sample * mask_weight |
| for down_sample, mask_weight in zip(down_samples, controlnet_masks[controlnet_index]) |
| ] |
| mid_sample *= controlnet_masks[controlnet_index][-1] |
|
|
| if down_block_res_samples is None and mid_block_res_sample is None: |
| down_block_res_samples = down_samples |
| mid_block_res_sample = mid_sample |
| else: |
| down_block_res_samples = [ |
| samples_prev + samples_curr |
| for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) |
| ] |
| mid_block_res_sample += mid_sample |
|
|
| if guess_mode and self.do_classifier_free_guidance: |
| |
| |
| |
| down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] |
| mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) |
|
|
| if ip_adapter_image is not None or ip_adapter_image_embeds is not None: |
| added_cond_kwargs["image_embeds"] = image_embeds |
|
|
| |
| noise_pred = self.unet( |
| latent_model_input, |
| t, |
| encoder_hidden_states=prompt_embeds, |
| cross_attention_kwargs=self.cross_attention_kwargs, |
| down_block_additional_residuals=down_block_res_samples, |
| mid_block_additional_residual=mid_block_res_sample, |
| added_cond_kwargs=added_cond_kwargs, |
| return_dict=False, |
| )[0] |
|
|
| |
| if self.do_classifier_free_guidance: |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
| |
| |
| if "variance_noise" in inspect.signature(self.scheduler.step).parameters and brownian_tree_noise_sampler is not None: |
| sigmas = self.scheduler.sigmas |
| noise = brownian_tree_noise_sampler(sigmas[i], sigmas[i + 1]).to(device=device, dtype=latents.dtype) |
| latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False, variance_noise=noise)[0] |
| else: |
| latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] |
|
|
| if callback_on_step_end is not None: |
| callback_kwargs = {} |
| for k in callback_on_step_end_tensor_inputs: |
| callback_kwargs[k] = locals()[k] |
| callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) |
|
|
| latents = callback_outputs.pop("latents", latents) |
| prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) |
| negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) |
|
|
| |
| if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
| progress_bar.update() |
| if callback is not None and i % callback_steps == 0: |
| step_idx = i // getattr(self.scheduler, "order", 1) |
| callback(step_idx, t, latents) |
|
|
| |
| |
| if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: |
| self.unet.to("cpu") |
| self.controlnet.to("cpu") |
| torch.cuda.empty_cache() |
|
|
| if not output_type == "latent": |
| |
| needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast |
|
|
| if needs_upcasting: |
| self.upcast_vae() |
| latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) |
|
|
| |
| |
| has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None |
| has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None |
| if has_latents_mean and has_latents_std: |
| latents_mean = ( |
| torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) |
| ) |
| latents_std = ( |
| torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) |
| ) |
| latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean |
| else: |
| latents = latents / self.vae.config.scaling_factor |
|
|
| image = self.vae.decode(latents, return_dict=False)[0] |
|
|
| |
| if needs_upcasting: |
| self.vae.to(dtype=torch.float16) |
| else: |
| image = latents |
| return StableDiffusionXLPipelineOutput(images=image) |
|
|
| |
| if self.watermark is not None: |
| image = self.watermark.apply_watermark(image) |
|
|
| image = self.image_processor.postprocess(image, output_type=output_type) |
|
|
| |
| self.maybe_free_model_hooks() |
|
|
| if not return_dict: |
| return (image,) |
|
|
| return StableDiffusionXLPipelineOutput(images=image) |
|
|