| import inspect
|
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
| import numpy as np
|
| import PIL.Image
|
| import torch
|
| import torch.nn.functional as F
|
| from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
|
|
| from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
| from diffusers.models import AutoencoderKL, ImageProjection
|
| from diffusers.models.lora import adjust_lora_scale_text_encoder
|
| from diffusers.schedulers import KarrasDiffusionSchedulers
|
| from diffusers.utils import (
|
| USE_PEFT_BACKEND,
|
| deprecate,
|
| logging,
|
| replace_example_docstring,
|
| scale_lora_layers,
|
| unscale_lora_layers,
|
| )
|
| from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
| from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| from model import UNet2DConditionModelEx
|
|
|
|
|
| from huggingface_hub.utils import validate_hf_hub_args
|
|
|
|
|
# Module-level logger from diffusers' logging utilities; used for the safety-checker warning in
# `__init__` and the prompt-truncation warning in `encode_prompt`.
logger = logging.get_logger(__name__)


# Usage example for the pipeline (Canny edge map as the extra condition).
# NOTE(review): presumably injected into the `Examples:` section of `__call__`'s docstring via
# `replace_example_docstring` (imported at the top of this file); the decorator use is outside this view.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # !pip install opencv-python transformers accelerate
>>> from diffusers import UniPCMultistepScheduler
>>> from diffusers.utils import load_image
>>> from model import UNet2DConditionModelEx
>>> from pipeline import StableDiffusionControlLoraV3Pipeline
>>> import numpy as np
>>> import torch

>>> import cv2
>>> from PIL import Image

>>> # download an image
>>> image = load_image(
... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
... )
>>> image = np.array(image)

>>> # get canny image
>>> image = cv2.Canny(image, 100, 200)
>>> image = image[:, :, None]
>>> image = np.concatenate([image, image, image], axis=2)
>>> canny_image = Image.fromarray(image)

>>> # load stable diffusion v1-5 and control-lora-v3
>>> unet: UNet2DConditionModelEx = UNet2DConditionModelEx.from_pretrained(
... "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
... )
>>> unet = unet.add_extra_conditions(["canny"])
>>> pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
... )
>>> # load attention processors
>>> pipe.load_lora_weights("HighCWu/sd-control-lora-v3-canny")

>>> # speed up diffusion process with faster scheduler and memory optimization
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> # remove following line if xformers is not installed
>>> pipe.enable_xformers_memory_efficient_attention()

>>> pipe.enable_model_cpu_offload()

>>> # generate image
>>> generator = torch.manual_seed(0)
>>> image = pipe(
... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
... ).images[0]
```
"""
|
|
|
|
|
|
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule.

    At most one of `timesteps` / `sigmas` may be given. When either is given, the scheduler must accept the
    corresponding keyword in `set_timesteps`; otherwise a plain step count is used. Extra `kwargs` are
    forwarded verbatim to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose `set_timesteps` is called and whose `timesteps` attribute is read back.
        num_inference_steps (`int`, *optional*):
            Number of denoising steps; used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the device for the schedule. `None` leaves placement to the scheduler.
        timesteps (`List[int]`, *optional*):
            Explicit timestep schedule, overriding the scheduler's spacing strategy.
        sigmas (`List[float]`, *optional*):
            Explicit sigma schedule, overriding the scheduler's spacing strategy.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` after the call, and the number of inference
        steps. With a custom schedule the count is the schedule's length; otherwise it is
        `num_inference_steps` unchanged.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom schedule that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    supported_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
|
|
|
|
| class StableDiffusionControlLoraV3Pipeline(
|
| DiffusionPipeline,
|
| StableDiffusionMixin,
|
| TextualInversionLoaderMixin,
|
| LoraLoaderMixin,
|
| IPAdapterMixin,
|
| FromSingleFileMixin,
|
| ):
|
    r"""
    Pipeline for text-to-image generation using Stable Diffusion with extra condition guidance.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModelEx`]):
            A `UNet2DConditionModelEx` to denoise the encoded image latents with extra conditions. `check_inputs`
            reads `unet.extra_condition_names` and requires at least one extra condition to be registered.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents (annotated as
            `KarrasDiffusionSchedulers` in `__init__`), for example [`DDIMScheduler`], [`LMSDiscreteScheduler`],
            or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms. May be `None`, in which case `run_safety_checker` skips the check.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
            Also used by `encode_image` to turn non-tensor IP-Adapter images into pixel values.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`], *optional*, defaults to `None`):
            Vision encoder used by `encode_image` to embed IP-Adapter images.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            If `True`, a warning is logged when `safety_checker` is `None`. The value is stored in the pipeline config.
    """

    # Component order for model CPU offloading.
    # NOTE(review): consumed by the `DiffusionPipeline` base class (not shown here) — confirm semantics there.
    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    # Components allowed to be `None`; e.g. `image_encoder` defaults to `None` in `__init__`.
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    # Components kept out of CPU offloading.
    _exclude_from_cpu_offload = ["safety_checker"]
    # Tensor names a step-end callback may request; `check_inputs` rejects any name not in this list.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
|
|
| def __init__(
|
| self,
|
| vae: AutoencoderKL,
|
| text_encoder: CLIPTextModel,
|
| tokenizer: CLIPTokenizer,
|
| unet: UNet2DConditionModelEx,
|
| scheduler: KarrasDiffusionSchedulers,
|
| safety_checker: StableDiffusionSafetyChecker,
|
| feature_extractor: CLIPImageProcessor,
|
| image_encoder: CLIPVisionModelWithProjection = None,
|
| requires_safety_checker: bool = True,
|
| ):
|
| super().__init__()
|
|
|
| if safety_checker is None and requires_safety_checker:
|
| logger.warning(
|
| f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
| " results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
| " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| )
|
|
|
| if safety_checker is not None and feature_extractor is None:
|
| raise ValueError(
|
| "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| )
|
|
|
| self.register_modules(
|
| vae=vae,
|
| text_encoder=text_encoder,
|
| tokenizer=tokenizer,
|
| unet=unet,
|
| scheduler=scheduler,
|
| safety_checker=safety_checker,
|
| feature_extractor=feature_extractor,
|
| image_encoder=image_encoder,
|
| )
|
| self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
| self.register_to_config(requires_safety_checker=requires_safety_checker)
|
|
|
|
|
| def _encode_prompt(
|
| self,
|
| prompt,
|
| device,
|
| num_images_per_prompt,
|
| do_classifier_free_guidance,
|
| negative_prompt=None,
|
| prompt_embeds: Optional[torch.Tensor] = None,
|
| negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| lora_scale: Optional[float] = None,
|
| **kwargs,
|
| ):
|
| deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
| deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
|
|
| prompt_embeds_tuple = self.encode_prompt(
|
| prompt=prompt,
|
| device=device,
|
| num_images_per_prompt=num_images_per_prompt,
|
| do_classifier_free_guidance=do_classifier_free_guidance,
|
| negative_prompt=negative_prompt,
|
| prompt_embeds=prompt_embeds,
|
| negative_prompt_embeds=negative_prompt_embeds,
|
| lora_scale=lora_scale,
|
| **kwargs,
|
| )
|
|
|
|
|
| prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
|
|
| return prompt_embeds
|
|
|
|
|
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            `Tuple[torch.Tensor, Optional[torch.Tensor]]`: `(prompt_embeds, negative_prompt_embeds)`. Both are
            tiled `num_images_per_prompt` times along the batch axis when guidance is on. `negative_prompt_embeds`
            is `None` when guidance is off and none were passed, and is returned unchanged when guidance is off
            but one was passed.

        Raises:
            TypeError: If `negative_prompt` and `prompt` have different types.
            ValueError: If a list `negative_prompt` does not match the prompt batch size.
        """
        # Apply the LoRA scale to the text encoder for this call; it is undone at the end (PEFT backend).
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Batch size comes from the prompt when given, otherwise from the precomputed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # Textual inversion: let the mixin rewrite the prompt's special tokens before tokenizing.
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize again without truncation only to detect (and report) dropped tokens.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # Only pass an attention mask if the text encoder config opts in.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # NOTE(review): index -1 is presumably the `hidden_states` tuple (last populated output field
                # when output_hidden_states=True); then take the layer `clip_skip` positions before the last.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # The selected intermediate layer is passed through the text model's final layer norm
                # so it is normalized like the encoder's regular output.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Pick the dtype from the text encoder, falling back to the UNet, then to the embeddings themselves.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # Tile each prompt's embeddings `num_images_per_prompt` times along the sequence axis, then fold the
        # copies into the batch axis: (bs, seq, dim) -> (bs * num_images_per_prompt, seq, dim).
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # Build unconditional embeddings for classifier-free guidance unless they were supplied.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Textual inversion: rewrite special tokens in the negative prompt too.
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the negative prompt to the same sequence length as the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # Same tiling as for the positive embeddings, using the prompt batch size.
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Undo the LoRA scaling applied at the top of this method.
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
|
|
|
|
|
| def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
| dtype = next(self.image_encoder.parameters()).dtype
|
|
|
| if not isinstance(image, torch.Tensor):
|
| image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
|
|
| image = image.to(device=device, dtype=dtype)
|
| if output_hidden_states:
|
| image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
| uncond_image_enc_hidden_states = self.image_encoder(
|
| torch.zeros_like(image), output_hidden_states=True
|
| ).hidden_states[-2]
|
| uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
| num_images_per_prompt, dim=0
|
| )
|
| return image_enc_hidden_states, uncond_image_enc_hidden_states
|
| else:
|
| image_embeds = self.image_encoder(image).image_embeds
|
| image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| uncond_image_embeds = torch.zeros_like(image_embeds)
|
|
|
| return image_embeds, uncond_image_embeds
|
|
|
|
|
| def prepare_ip_adapter_image_embeds(
|
| self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
| ):
|
| if ip_adapter_image_embeds is None:
|
| if not isinstance(ip_adapter_image, list):
|
| ip_adapter_image = [ip_adapter_image]
|
|
|
| if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
| raise ValueError(
|
| f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
| )
|
|
|
| image_embeds = []
|
| for single_ip_adapter_image, image_proj_layer in zip(
|
| ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
| ):
|
| output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
| single_image_embeds, single_negative_image_embeds = self.encode_image(
|
| single_ip_adapter_image, device, 1, output_hidden_state
|
| )
|
| single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
| single_negative_image_embeds = torch.stack(
|
| [single_negative_image_embeds] * num_images_per_prompt, dim=0
|
| )
|
|
|
| if do_classifier_free_guidance:
|
| single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
| single_image_embeds = single_image_embeds.to(device)
|
|
|
| image_embeds.append(single_image_embeds)
|
| else:
|
| repeat_dims = [1]
|
| image_embeds = []
|
| for single_image_embeds in ip_adapter_image_embeds:
|
| if do_classifier_free_guidance:
|
| single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
| single_image_embeds = single_image_embeds.repeat(
|
| num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
| )
|
| single_negative_image_embeds = single_negative_image_embeds.repeat(
|
| num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
| )
|
| single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
| else:
|
| single_image_embeds = single_image_embeds.repeat(
|
| num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
| )
|
| image_embeds.append(single_image_embeds)
|
|
|
| return image_embeds
|
|
|
|
|
| def run_safety_checker(self, image, device, dtype):
|
| if self.safety_checker is None:
|
| has_nsfw_concept = None
|
| else:
|
| if torch.is_tensor(image):
|
| feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
| else:
|
| feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
| safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
| image, has_nsfw_concept = self.safety_checker(
|
| images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
| )
|
| return image, has_nsfw_concept
|
|
|
|
|
| def decode_latents(self, latents):
|
| deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
| deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
|
|
| latents = 1 / self.vae.config.scaling_factor * latents
|
| image = self.vae.decode(latents, return_dict=False)[0]
|
| image = (image / 2 + 0.5).clamp(0, 1)
|
|
|
| image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| return image
|
|
|
|
|
| def prepare_extra_step_kwargs(self, generator, eta):
|
|
|
|
|
|
|
|
|
|
|
| accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| extra_step_kwargs = {}
|
| if accepts_eta:
|
| extra_step_kwargs["eta"] = eta
|
|
|
|
|
| accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| if accepts_generator:
|
| extra_step_kwargs["generator"] = generator
|
| return extra_step_kwargs
|
|
|
| def check_inputs(
|
| self,
|
| prompt,
|
| image,
|
| callback_steps,
|
| negative_prompt=None,
|
| prompt_embeds=None,
|
| negative_prompt_embeds=None,
|
| ip_adapter_image=None,
|
| ip_adapter_image_embeds=None,
|
| extra_condition_scale=1.0,
|
| control_guidance_start=0.0,
|
| control_guidance_end=1.0,
|
| callback_on_step_end_tensor_inputs=None,
|
| ):
|
| if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
| raise ValueError(
|
| f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| f" {type(callback_steps)}."
|
| )
|
|
|
| if callback_on_step_end_tensor_inputs is not None and not all(
|
| k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| ):
|
| raise ValueError(
|
| f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| )
|
|
|
| if prompt is not None and prompt_embeds is not None:
|
| raise ValueError(
|
| f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| " only forward one of the two."
|
| )
|
| elif prompt is None and prompt_embeds is None:
|
| raise ValueError(
|
| "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| )
|
| elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
|
|
| if negative_prompt is not None and negative_prompt_embeds is not None:
|
| raise ValueError(
|
| f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| )
|
|
|
| if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| raise ValueError(
|
| "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| f" {negative_prompt_embeds.shape}."
|
| )
|
|
|
|
|
| unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
|
| num_extra_conditions = len(unet.extra_condition_names)
|
| if num_extra_conditions == 1:
|
| self.check_image(image, prompt, prompt_embeds)
|
| elif num_extra_conditions > 1:
|
| if not isinstance(image, list):
|
| raise TypeError("For multiple extra conditions: `image` must be type `list`")
|
|
|
|
|
|
|
| elif any(isinstance(i, list) for i in image):
|
| transposed_image = [list(t) for t in zip(*image)]
|
| if len(transposed_image) != num_extra_conditions:
|
| raise ValueError(
|
| f"For multiple extra conditions: if you pass`image` as a list of list, each sublist must have the same length as the number of extra conditions, but the sublists in `image` got {len(transposed_image)} images and {num_extra_conditions} extra conditions."
|
| )
|
| for image_ in transposed_image:
|
| self.check_image(image_, prompt, prompt_embeds)
|
| elif len(image) != num_extra_conditions:
|
| raise ValueError(
|
| f"For multiple extra conditions: `image` must have the same length as the number of extra conditions, but got {len(image)} images and {num_extra_conditions} extra conditions."
|
| )
|
| else:
|
| for image_ in image:
|
| self.check_image(image_, prompt, prompt_embeds)
|
| else:
|
| assert False
|
|
|
|
|
| if num_extra_conditions == 1:
|
| if not isinstance(extra_condition_scale, float):
|
| raise TypeError("For single extra condition: `extra_condition_scale` must be type `float`.")
|
| elif num_extra_conditions >= 1:
|
| if isinstance(extra_condition_scale, list):
|
| if any(isinstance(i, list) for i in extra_condition_scale):
|
| raise ValueError(
|
| "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
|
| "The conditioning scale must be fixed across the batch."
|
| )
|
| elif isinstance(extra_condition_scale, list) and len(extra_condition_scale) != num_extra_conditions:
|
| raise ValueError(
|
| "For multiple extra conditions: When `extra_condition_scale` is specified as `list`, it must have"
|
| " the same length as the number of extra conditions"
|
| )
|
| else:
|
| assert False
|
|
|
| if not isinstance(control_guidance_start, (tuple, list)):
|
| control_guidance_start = [control_guidance_start]
|
|
|
| if not isinstance(control_guidance_end, (tuple, list)):
|
| control_guidance_end = [control_guidance_end]
|
|
|
| if len(control_guidance_start) != len(control_guidance_end):
|
| raise ValueError(
|
| f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
|
| )
|
|
|
| if num_extra_conditions > 1:
|
| if len(control_guidance_start) != num_extra_conditions:
|
| raise ValueError(
|
| f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_extra_conditions} extra conditions available. Make sure to provide {num_extra_conditions}."
|
| )
|
|
|
| for start, end in zip(control_guidance_start, control_guidance_end):
|
| if start >= end:
|
| raise ValueError(
|
| f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
|
| )
|
| if start < 0.0:
|
| raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
|
| if end > 1.0:
|
| raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
|
|
| if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
| raise ValueError(
|
| "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
| )
|
|
|
| if ip_adapter_image_embeds is not None:
|
| if not isinstance(ip_adapter_image_embeds, list):
|
| raise ValueError(
|
| f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
|
| )
|
| elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
|
| raise ValueError(
|
| f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
| )
|
|
|
| def check_image(self, image, prompt, prompt_embeds):
|
| image_is_pil = isinstance(image, PIL.Image.Image)
|
| image_is_tensor = isinstance(image, torch.Tensor)
|
| image_is_np = isinstance(image, np.ndarray)
|
| image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
| image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
| image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
|
|
| if (
|
| not image_is_pil
|
| and not image_is_tensor
|
| and not image_is_np
|
| and not image_is_pil_list
|
| and not image_is_tensor_list
|
| and not image_is_np_list
|
| ):
|
| raise TypeError(
|
| f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
|
| )
|
|
|
| if image_is_pil:
|
| image_batch_size = 1
|
| else:
|
| image_batch_size = len(image)
|
|
|
| if prompt is not None and isinstance(prompt, str):
|
| prompt_batch_size = 1
|
| elif prompt is not None and isinstance(prompt, list):
|
| prompt_batch_size = len(prompt)
|
| elif prompt_embeds is not None:
|
| prompt_batch_size = prompt_embeds.shape[0]
|
|
|
| if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
| raise ValueError(
|
| f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
| )
|
|
|
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
):
    """Preprocess a condition image and expand it to the generation batch.

    The input goes through ``self.image_processor.preprocess`` (resized to ``height``/``width``
    when given) and is cast to float32. A batch of exactly one image is repeated
    ``batch_size`` times. Any other batch is repeated ``num_images_per_prompt`` times per item
    (``repeat_interleave`` along dim 0). The result is moved to ``device``/``dtype``. With
    classifier-free guidance, the batch is concatenated with itself.

    Returns:
        torch.Tensor: the expanded condition batch.
    """
    pixels = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

    repeats = batch_size if pixels.shape[0] == 1 else num_images_per_prompt
    pixels = pixels.repeat_interleave(repeats, dim=0).to(device=device, dtype=dtype)

    return torch.cat([pixels, pixels]) if do_classifier_free_guidance else pixels
|
|
|
|
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Create (or adopt) the initial latents and scale them by the scheduler's initial sigma.

    The latent spatial size is ``int(height) // self.vae_scale_factor`` by
    ``int(width) // self.vae_scale_factor``.

    If ``latents`` is ``None``, fresh noise is drawn with ``randn_tensor``. Otherwise the given
    tensor is only moved to ``device`` (its dtype is left unchanged). Either way the result is
    multiplied by ``self.scheduler.init_noise_sigma``.

    Raises:
        ValueError: If ``generator`` is a list whose length differs from ``batch_size``.
    """
    factor = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, int(height) // factor, int(width) // factor)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    initial = (
        randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
        if latents is None
        else latents.to(device)
    )
    return initial * self.scheduler.init_noise_sigma
|
|
|
|
|
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Build sinusoidal embeddings of guidance-scale values. ``w`` is scaled by 1000, multiplied by
    ``embedding_dim // 2`` geometrically spaced frequencies, and the sines and cosines are
    concatenated. Odd ``embedding_dim`` is zero-padded by one column.

    Args:
        w (`torch.Tensor`):
            1-D tensor of values to embed. The embeddings are used to enrich timestep embeddings.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, on the same
        device as `w`.

    Raises:
        ValueError: If `w` is not 1-D.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if w.ndim != 1:
        raise ValueError(f"`w` must be a 1-D tensor, but got a tensor with shape {tuple(w.shape)}.")
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Allocate the frequency table on ``w``'s device so a CUDA ``w`` does not hit a
    # CPU/CUDA device mismatch in the product below.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        # sin/cos only fill an even number of columns; pad the last one with zeros.
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)  # internal invariant, not input validation
    return emb
|
|
|
@property
def guidance_scale(self):
    """Guidance weight; set from the ``guidance_scale`` argument of ``__call__``."""
    return self._guidance_scale


@property
def clip_skip(self):
    """Number of CLIP layers to skip; set from the ``clip_skip`` argument of ``__call__``."""
    return self._clip_skip


@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active.

    Requires ``guidance_scale > 1`` and a UNet whose ``config.time_cond_proj_dim`` is ``None``.
    UNets with that projection receive the guidance scale as an embedding instead. The UNet
    config is only consulted when the scale exceeds 1.
    """
    scale_above_one = self._guidance_scale > 1
    if not scale_above_one:
        return scale_above_one
    return self.unet.config.time_cond_proj_dim is None


@property
def cross_attention_kwargs(self):
    """Extra kwargs forwarded to the attention processors; set by ``__call__``."""
    return self._cross_attention_kwargs


@property
def num_timesteps(self):
    """Length of the timestep schedule used by the most recent ``__call__``."""
    return self._num_timesteps
|
|
|
@classmethod
@validate_hf_hub_args
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    **kwargs,
):
    """Load a LoRA state dict via the parent loader, inferring alphas when none are stored.

    If the parent returns ``network_alphas`` of ``None``, an alpha is synthesised for every
    ``.lora_A.`` key. The key is ``<prefix before ".lora_A.">.alpha`` and the value is
    ``weight.shape[0]``. That dimension is presumably the LoRA rank, which under the usual
    ``alpha / rank`` scaling would give an effective scale of 1. Otherwise the parent's alphas
    are returned unchanged.

    Returns:
        ``(state_dict, network_alphas)``
    """
    parent = super(StableDiffusionControlLoraV3Pipeline, cls)
    state_dict, network_alphas = parent.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
    if network_alphas is not None:
        return state_dict, network_alphas

    inferred_alphas = {
        key.split(".lora_A.")[0] + ".alpha": weight.shape[0]
        for key, weight in state_dict.items()
        if ".lora_A." in key
    }
    return state_dict, inferred_alphas
|
|
|
def load_lora_weights(
    self,
    pretrained_model_name_or_path_or_dict: Union[
        Union[str, Dict[str, torch.Tensor]],
        List[Union[str, Dict[str, torch.Tensor]]]
    ],
    adapter_name=None,
    **kwargs
):
    """Load LoRA weights, routing extra-condition LoRAs into their condition slot of ``conv_in``.

    The mode is chosen by ``adapter_name``:

    * ``adapter_name`` given and *not* one of ``unet.extra_condition_names``: the weights are
      loaded as an ordinary LoRA by the parent ``load_lora_weights``, and that adapter is
      activated with ``unet.set_adapter``.
    * Otherwise (``None``, or one of the extra-condition names): one LoRA is loaded per extra
      condition (only the named one when ``adapter_name`` is given), registered under the
      condition's name. The ``unet ... conv_in ... lora_A`` weight is widened along dim 1 so
      that the condition's input columns sit in that condition's slot, with zeros in the other
      conditions' slots. Finally ``unet.activate_extra_condition_adapters()`` is called.

    Args:
        pretrained_model_name_or_path_or_dict: A path/repo id or a state dict, or a list with
            exactly one entry per extra condition. A single (non-list) entry is reused for every
            condition.
        adapter_name: See above.
        **kwargs: Forwarded to the loaders. ``weight_name`` defaults to
            ``"pytorch_lora_weights.safetensors"``. In the extra-condition mode, ``subfolder``
            may be a list indexed by the condition's position.

    Raises:
        ValueError: If a list of sources is given whose length differs from the number of
            extra conditions.
    """
    unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
    condition_names = unet.extra_condition_names
    num_condition_names = len(condition_names)
    in_channels = unet.config.in_channels

    kwargs["weight_name"] = kwargs.pop("weight_name", "pytorch_lora_weights.safetensors")

    # Ordinary (non-condition) LoRA: defer entirely to the parent loader.
    if adapter_name is not None and adapter_name not in condition_names:
        unet._hf_peft_config_loaded = True
        super().load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
        unet.set_adapter(adapter_name)
        return

    sources = pretrained_model_name_or_path_or_dict
    if not isinstance(sources, list):
        sources = [sources] * num_condition_names
    if len(sources) != num_condition_names:
        raise ValueError(
            f"Expected {num_condition_names} LoRA sources (one per extra condition "
            f"{list(condition_names)}), but got {len(sources)}."
        )

    # Resolve the target slot(s) *before* loading, so that per-condition source and
    # ``subfolder`` lists are indexed by the condition actually being loaded.
    if adapter_name is None:
        targets = list(enumerate(condition_names))
    else:
        targets = [(condition_names.index(adapter_name), adapter_name)]

    for i, condition_name in targets:
        _kwargs = {**kwargs}
        subfolder = _kwargs.pop("subfolder", None)
        if isinstance(subfolder, list):
            subfolder = subfolder[i]

        source = sources[i]
        if isinstance(source, dict):
            # Shallow copy: one entry is replaced below, and the caller's dict must not be
            # mutated. A non-list dict is shared by every condition, so mutating it would make
            # later iterations re-widen an already-widened weight.
            state_dict = dict(source)
        else:
            state_dict, _ = self.lora_state_dict(source, subfolder=subfolder, **_kwargs)

        conv_in_lora_A_name, old_weight = next(
            (
                (k, v)
                for k, v in state_dict.items()
                if "unet." in k and ".conv_in." in k and ".lora_A." in k
            ),
            (None, None),
        )
        if conv_in_lora_A_name is not None:
            # Assumes the stored weight's dim-1 columns are [in_channels latent | this
            # condition's channels]. Zero blocks (each in_channels wide) fill the other slots.
            in_weight = old_weight[:, :in_channels]
            cond_weight = old_weight[:, in_channels:]
            zero_weight = torch.zeros_like(in_weight)
            state_dict[conv_in_lora_A_name] = torch.cat(
                [in_weight]
                + [zero_weight] * i
                + [cond_weight]
                + [zero_weight] * (num_condition_names - i - 1),
                dim=1,
            )

        super().load_lora_weights(state_dict, condition_name, **_kwargs)

    unet.activate_extra_condition_adapters()
|
|
|
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    extra_condition_scale: Union[float, List[float]] = 1.0,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The extra input condition that guides the `unet` after being encoded by the `vae`. Each condition
            image is preprocessed and its `vae` latent representation is passed to the UNet. The dimensions of
            the output image default to `image`'s dimensions; if height and/or width are passed, `image` is
            resized accordingly. If multiple extra conditions are specified in `unet`, images must be passed
            as a list with one entry per condition. When `prompt` is a list and a list of images is passed,
            each image is paired with the corresponding prompt. For multiple extra conditions, a list of
            image lists can be passed to batch over both prompts and conditions.
        height (`int`, *optional*):
            The height in pixels of the generated image. Defaults to the height of the preprocessed `image`.
        width (`int`, *optional*):
            The width in pixels of the generated image. Defaults to the width of the preprocessed `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`; `"latent"`
            returns the latents without decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*, deprecated, passed via `**kwargs`):
            A function called every `callback_steps` steps during inference with the arguments
            `callback(step: int, timestep: int, latents: torch.Tensor)`. Use `callback_on_step_end` instead.
        callback_steps (`int`, *optional*, deprecated, passed via `**kwargs`, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        extra_condition_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The control LoRA scale of `unet`. If multiple extra conditions are specified in `unet`, you can set
            the corresponding scale as a list; a single number is applied to every condition.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The fraction of total steps at which the extra condition starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The fraction of total steps at which the extra condition stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.

    Raises:
        ValueError: If `unet` has no extra conditions configured.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    unet: UNet2DConditionModelEx = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
    num_extra_conditions = len(unet.extra_condition_names)
    if num_extra_conditions == 0:
        # Fail early with a clear message rather than after prompt encoding.
        raise ValueError("`unet` has no extra conditions; at least one extra condition is required.")

    # Broadcast scalar guidance windows so there is one (start, end) pair per extra condition.
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        control_guidance_start, control_guidance_end = (
            num_extra_conditions * [control_guidance_start],
            num_extra_conditions * [control_guidance_end],
        )

    self.check_inputs(
        prompt,
        image,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        extra_condition_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # Accept any real scalar (e.g. ``1``, not only ``1.0``); otherwise an int would reach the
    # ``zip`` in the denoising loop and raise TypeError.
    if num_extra_conditions > 1 and isinstance(extra_condition_scale, (int, float)):
        extra_condition_scale = [extra_condition_scale] * num_extra_conditions

    # Encode the text prompt.
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )

    # With CFG, unconditional and conditional embeddings are batched together (uncond first).
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # Preprocess the extra-condition images and encode them with the VAE (mode of the posterior,
    # scaled by the VAE scaling factor). Height/width are taken from the preprocessed images.
    if num_extra_conditions == 1:
        image = self.prepare_image(
            image=image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=unet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )
        height, width = image.shape[-2:]
        image = (
            self.vae.encode(image.to(dtype=unet.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
        )
    else:
        images = []

        # A list of per-prompt lists is transposed into one list per condition.
        if isinstance(image[0], list):
            image = [list(t) for t in zip(*image)]

        for image_ in image:
            image_ = self.prepare_image(
                image=image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=unet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
            )
            images.append(image_)

        height, width = images[0].shape[-2:]
        image = [
            self.vae.encode(cond_image.to(dtype=unet.dtype)).latent_dist.mode() * self.vae.config.scaling_factor
            for cond_image in images
        ]

    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    self._num_timesteps = len(timesteps)

    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # UNets with a time-conditioning projection receive the guidance scale as an embedding.
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # Per-step on/off mask for each condition: 1.0 while step i lies inside [start, end] (as
    # fractions of the schedule), else 0.0. A single condition stores a scalar, not a list.
    extra_condition_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        extra_condition_keep.append(keeps[0] if num_extra_conditions == 1 else keeps)

    # The legacy callback is documented to fire every step when ``callback_steps`` is omitted.
    callback_interval = 1 if callback_steps is None else callback_steps

    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    is_unet_compiled = is_compiled_module(self.unet)
    is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
    try:
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if is_unet_compiled and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Effective per-condition scale for this step (user scale x on/off mask).
                if isinstance(extra_condition_keep[i], list):
                    cond_scale = [c * s for c, s in zip(extra_condition_scale, extra_condition_keep[i])]
                else:
                    extra_cond_scale = extra_condition_scale
                    if isinstance(extra_cond_scale, list):
                        extra_cond_scale = extra_cond_scale[0]
                    cond_scale = extra_cond_scale * extra_condition_keep[i]

                self.unet.set_extra_condition_scale(cond_scale)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    extra_conditions=image,
                    return_dict=False,
                )[0]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_interval == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)
    finally:
        # Restore the default condition scale even if denoising raises, so the UNet is not left
        # with the last step's (possibly zeroed) scale.
        self.unet.set_extra_condition_scale(1.0)

    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        torch.cuda.empty_cache()

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
|
|