| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import inspect |
| | from dataclasses import dataclass |
| | from typing import Any, Callable, Dict, List, Optional, Union |
| |
|
| | import numpy as np |
| | import PIL.Image |
| | import torch |
| | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer |
| |
|
| | from ...image_processor import VaeImageProcessorLDM3D |
| | from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin |
| | from ...models import AutoencoderKL, UNet2DConditionModel |
| | from ...models.lora import adjust_lora_scale_text_encoder |
| | from ...schedulers import KarrasDiffusionSchedulers |
| | from ...utils import ( |
| | USE_PEFT_BACKEND, |
| | BaseOutput, |
| | deprecate, |
| | logging, |
| | replace_example_docstring, |
| | scale_lora_layers, |
| | unscale_lora_layers, |
| | ) |
| | from ...utils.torch_utils import randn_tensor |
| | from ..pipeline_utils import DiffusionPipeline |
| | from .safety_checker import StableDiffusionSafetyChecker |
| |
|
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| | EXAMPLE_DOC_STRING = """ |
| | Examples: |
| | ```python |
| | >>> from diffusers import StableDiffusionLDM3DPipeline |
| | |
| | >>> pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c") |
| | >>> pipe = pipe.to("cuda") |
| | |
| | >>> prompt = "a photo of an astronaut riding a horse on mars" |
| | >>> output = pipe(prompt) |
| | >>> rgb_image, depth_image = output.rgb, output.depth |
| | >>> rgb_image[0].save("astronaut_ldm3d_rgb.jpg") |
| | >>> depth_image[0].save("astronaut_ldm3d_depth.png") |
| | ``` |
| | """ |
| |
|
| |
|
@dataclass
class LDM3DPipelineOutput(BaseOutput):
    """
    Container for the results produced by the LDM3D pipeline.

    Args:
        rgb (`List[PIL.Image.Image]` or `np.ndarray`):
            The RGB outputs: a list of denoised PIL images of length `batch_size`, or a NumPy array of shape
            `(batch_size, height, width, num_channels)`.
        depth (`List[PIL.Image.Image]` or `np.ndarray`):
            The depth outputs: a list of denoised PIL images of length `batch_size`, or a NumPy array of shape
            `(batch_size, height, width, num_channels)`.
        nsfw_content_detected (`List[bool]`, *optional*):
            One flag per generated image telling whether it was classified as "not-safe-for-work" (nsfw)
            content, or `None` if safety checking could not be performed.
    """

    rgb: Union[List[PIL.Image.Image], np.ndarray]
    depth: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
| |
|
| |
|
| | class StableDiffusionLDM3DPipeline( |
| | DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin |
| | ): |
| | r""" |
| | Pipeline for text-to-image and 3D generation using LDM3D. |
| | |
| | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods |
| | implemented for all pipelines (downloading, saving, running on a particular device, etc.). |
| | |
| | The pipeline also inherits the following loading methods: |
| | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings |
| | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights |
| | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights |
| | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files |
| | |
| | Args: |
| | vae ([`AutoencoderKL`]): |
| | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. |
| | text_encoder ([`~transformers.CLIPTextModel`]): |
| | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). |
| | tokenizer ([`~transformers.CLIPTokenizer`]): |
| | A `CLIPTokenizer` to tokenize text. |
| | unet ([`UNet2DConditionModel`]): |
| | A `UNet2DConditionModel` to denoise the encoded image latents. |
| | scheduler ([`SchedulerMixin`]): |
| | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of |
| | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. |
| | safety_checker ([`StableDiffusionSafetyChecker`]): |
| | Classification module that estimates whether generated images could be considered offensive or harmful. |
| | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details |
| | about a model's potential harms. |
| | feature_extractor ([`~transformers.CLIPImageProcessor`]): |
| | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. |
| | """ |
| | model_cpu_offload_seq = "text_encoder->unet->vae" |
| | _optional_components = ["safety_checker", "feature_extractor"] |
| | _exclude_from_cpu_offload = ["safety_checker"] |
| |
|
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and set up the LDM3D image processor.

    Args:
        vae: Autoencoder used to decode latents; its `block_out_channels` config determines
            `vae_scale_factor`.
        text_encoder: Text encoder used by `encode_prompt`.
        tokenizer: Tokenizer paired with `text_encoder`.
        unet: Denoising network.
        scheduler: Scheduler driving the denoising loop.
        safety_checker: Optional safety checker; may be `None` to disable checking.
        feature_extractor: Feature extractor feeding the safety checker; required whenever
            `safety_checker` is not `None`.
        requires_safety_checker: Stored in the pipeline config. When `True` and `safety_checker` is
            `None`, a warning is logged.

    Raises:
        ValueError: If `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message must be an f-string, otherwise the literal text "{self.__class__}" is shown.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Each VAE block after the first halves the spatial resolution, hence the power of two.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
| |
|
| | |
def enable_vae_slicing(self):
    r"""
    Turn on sliced VAE decoding.

    With slicing enabled the VAE splits the input tensor into slices and decodes them in several steps, which
    saves memory and allows larger batch sizes.
    """
    autoencoder = self.vae
    autoencoder.enable_slicing()
| |
|
| | |
def disable_vae_slicing(self):
    r"""
    Turn off sliced VAE decoding.

    If `enable_vae_slicing` was called earlier, decoding goes back to being computed in a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_slicing()
| |
|
| | |
def enable_vae_tiling(self):
    r"""
    Turn on tiled VAE processing.

    With tiling enabled the VAE splits the input tensor into tiles and computes decoding and encoding in
    several steps. This saves a large amount of memory and allows processing larger images.
    """
    autoencoder = self.vae
    autoencoder.enable_tiling()
| |
|
| | |
def disable_vae_tiling(self):
    r"""
    Turn off tiled VAE processing.

    If `enable_vae_tiling` was called earlier, decoding goes back to being computed in a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_tiling()
| |
|
| | |
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    **kwargs,
):
    """
    Deprecated wrapper around `encode_prompt()`.

    Emits a deprecation notice, forwards every argument to `encode_prompt()`, and returns the legacy output
    format: a single tensor with the negative embeddings concatenated in front of the positive ones (instead
    of the tuple that `encode_prompt()` returns).
    """
    deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
    deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

    embeds_pair = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        **kwargs,
    )

    # Backwards compatibility: old callers expect [negative, positive] stacked into one tensor.
    negative_then_positive = [embeds_pair[1], embeds_pair[0]]
    return torch.cat(negative_then_positive)
| |
|
| | |
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encode the prompt (and, for classifier-free guidance, the negative prompt) into text-encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt to encode. May be `None` when `prompt_embeds` is supplied.
        device (`torch.device`):
            Device the token ids and resulting embeddings are moved to.
        num_images_per_prompt (`int`):
            Number of images generated per prompt; the embeddings are repeated this many times.
        do_classifier_free_guidance (`bool`):
            Whether classifier-free guidance is used. When `True`, negative embeddings are produced as well.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) not to guide the generation. Only used when `do_classifier_free_guidance` is `True` and
            `negative_prompt_embeds` is not supplied; an empty string per batch entry is used when it is `None`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. When given, `prompt` is not tokenized or encoded.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. When given, `negative_prompt` is not tokenized or encoded.
        lora_scale (`float`, *optional*):
            LoRA scale applied to the LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to skip from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer is used.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. The second element is returned exactly as passed
        in (possibly `None`) when `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `negative_prompt` and `prompt` are of different types.
        ValueError: If a list `negative_prompt` does not match the prompt batch size.
    """

    def _attention_mask_for(tokenized):
        # A mask is only forwarded when the text encoder config explicitly enables it.
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            return tokenized.attention_mask.to(device)
        return None

    # Record the requested LoRA scale on the pipeline and apply it to the text encoder.
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

        if USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        else:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when there is one, otherwise from the given embeddings.
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # Let textual inversion rewrite the prompt first, if that mixin is present.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn when tokenizing without truncation yields something different (i.e. text was cut off).
        was_truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        )
        if was_truncated:
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        attention_mask = _attention_mask_for(text_inputs)

        if clip_skip is None:
            encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = encoder_output[0]
        else:
            encoder_output = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # The last output entry is indexed from the end by `clip_skip + 1`; presumably it holds
            # the per-layer hidden states requested via `output_hidden_states=True` (TODO confirm).
            prompt_embeds = encoder_output[-1][-(clip_skip + 1)]
            # The intermediate hidden state is passed through the encoder's final layer norm
            # before being used as the prompt embedding.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Pick the dtype the embeddings are cast to: text encoder first, then unet, else keep as is.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    # Duplicate the embeddings once per requested image (repeat + view).
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # Build unconditional embeddings for classifier-free guidance when none were supplied.
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        attention_mask = _attention_mask_for(uncond_input)

        uncond_output = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = uncond_output[0]

    if do_classifier_free_guidance:
        # Duplicate the unconditional embeddings once per requested image, like the positive ones.
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
        # Undo the LoRA scaling applied at the top of this method.
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
| |
|
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on the generated image.

    Args:
        image: Generated image, either a tensor or a NumPy-style array accepted by
            `self.image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the CLIP pixel values are cast to before being passed to the safety checker.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`. When no safety checker is configured, `image` is returned
        untouched and `has_nsfw_concept` is `None`.
    """
    if self.safety_checker is None:
        return image, None

    if torch.is_tensor(image):
        pil_outputs = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_outputs = self.image_processor.numpy_to_pil(image)

    # Only the first element of the processor output (the RGB part) is fed to the feature extractor.
    rgb_pil = pil_outputs[0]
    clip_features = self.feature_extractor(rgb_pil, return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return image, has_nsfw_concept
| |
|
| | |
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments supported by the current scheduler's `step` method.

    Not every scheduler accepts the same arguments, so `eta` and `generator` are only forwarded when the
    signature of `self.scheduler.step` declares a parameter with that name.

    Args:
        generator: Random generator(s) to forward as `generator`.
        eta: Value to forward as `eta`.

    Returns:
        `dict`: Keyword arguments to splat into `self.scheduler.step(...)`.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))

    extra_step_kwargs = {}
    for name, value in candidates:
        if name in step_parameters:
            extra_step_kwargs[name] = value
    return extra_step_kwargs
| |
|
| | |
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of a pipeline call and raise on the first problem found.

    Checks, in order: image size divisibility by 8, `callback_steps` being a positive integer (when not
    `None`), `callback_on_step_end_tensor_inputs` being known names, that exactly one of `prompt` /
    `prompt_embeds` is given and `prompt` is a `str` or `list`, that `negative_prompt` and
    `negative_prompt_embeds` are not both given, and that directly passed embeddings have matching shapes.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None:
        is_positive_int = isinstance(callback_steps, int) and callback_steps > 0
        if not is_positive_int:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
| |
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Produce the initial latents for the denoising loop.

    Fresh noise is drawn with `randn_tensor` when `latents` is `None`; otherwise the given `latents` are moved
    to `device`. In both cases the result is multiplied by `self.scheduler.init_noise_sigma`.

    Args:
        batch_size: Effective batch size (first dimension of the latents).
        num_channels_latents: Number of latent channels.
        height: Image height in pixels; divided by `self.vae_scale_factor` for the latent height.
        width: Image width in pixels; divided by `self.vae_scale_factor` for the latent width.
        dtype: Dtype of freshly sampled latents.
        device: Target device.
        generator: A generator or a list of generators; a list must have `batch_size` entries.
        latents: Optional pre-generated latents.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    downscale = self.vae_scale_factor
    shape = (batch_size, num_channels_latents, height // downscale, width // downscale)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # Scale the starting noise by the scheduler's initial noise sigma.
    return latents * self.scheduler.init_noise_sigma
| |
|
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 49,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 49):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only
            forwarded to schedulers whose `step` method accepts an `eta` argument.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. With
            `"latent"`, VAE decoding and the safety checker are skipped and the latents are post-processed
            directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an [`LDM3DPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
    Examples:

    Returns:
        [`LDM3DPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, an [`LDM3DPipelineOutput`] is returned, otherwise a `tuple` is returned
            where the first element is the pair `(rgb, depth)` of generated outputs and the second element is a
            list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work"
            (nsfw) content, or `None` when the safety checker was not run.
    """
    # 0. Default height and width come from the unet sample size scaled up by the VAE factor.
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Validate the inputs; raises on the first problem.
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Derive the batch size from the prompt, or from the embeddings when no prompt is given.
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # Classifier-free guidance is only active for a guidance scale strictly above 1.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode the input prompt.
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clip_skip=clip_skip,
    )
    # For classifier-free guidance the unconditional and text embeddings are concatenated into
    # one batch so that a single unet forward pass handles both.
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare the timesteps.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare the latent variables.
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Collect the extra kwargs (`eta`, `generator`) that the scheduler step accepts.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Duplicate the latents when doing classifier-free guidance (uncond + text halves).
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual.
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # Apply guidance: move from the unconditional prediction towards the text prediction.
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Step the scheduler to get the latents for the next iteration.
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # Update the progress bar / invoke the callback on the last step and after warmup on
            # every `scheduler.order`-th step.
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Flagged images are excluded from denormalization during post-processing.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models.
    self.maybe_free_model_hooks()

    if not return_dict:
        return ((rgb, depth), has_nsfw_concept)

    return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept)
| |
|