| | import math |
| | import numbers |
| | from typing import Any, Callable, Dict, List, Optional, Union |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| | from diffusers.image_processor import PipelineImageInput |
| | from diffusers.models import AsymmetricAutoencoderKL, ImageProjection |
| | from diffusers.models.attention_processor import Attention, AttnProcessor |
| | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
| | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import ( |
| | StableDiffusionInpaintPipeline, |
| | retrieve_timesteps, |
| | ) |
| | from diffusers.utils import deprecate |
| |
|
| |
|
class RASGAttnProcessor:
    """Attention processor that records raw attention scores for RASG guidance.

    At the single UNet resolution selected by ``scale_factor`` the processor computes
    the pre-softmax scores with the module-level ``get_attention_scores`` helper, keeps
    them on ``self.attention_scores`` (the pipeline reads them to build its guidance
    loss) and continues with their softmax. At every other resolution it delegates
    score computation to ``attn.get_attention_scores``.

    Args:
        mask: Inpainting mask; only its last two dimensions are used, as the reference
            pixel count for detecting which UNet resolution is running.
        token_idx: Prompt token indices (stored for the caller; not used here).
        scale_factor: Downscale factor, relative to the mask, at which scores are kept.
    """

    def __init__(self, mask, token_idx, scale_factor):
        mask_h, mask_w = mask.shape[-2], mask.shape[-1]
        # Pixel count of the mask (the misspelled attribute name is kept for compatibility).
        self.mask_resoltuion = mask_w * mask_h
        self.scale_factor = scale_factor
        self.token_idx = token_idx
        self.mask = mask
        # Overwritten on each call made at the tracked resolution.
        self.attention_scores = None

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        """Run attention and, at the tracked resolution, store the raw scores."""
        # mask pixels / query tokens == (spatial downscale factor) ** 2
        downscale_factor = self.mask_resoltuion // hidden_states.shape[1]
        keep_scores = downscale_factor == self.scale_factor**2
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        projected_query = attn.to_q(hidden_states)

        # Self-attention uses the (group-normed) hidden states as context.
        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        query, key, value = (
            attn.head_to_batch_dim(projection)
            for projection in (projected_query, attn.to_k(context), attn.to_v(context))
        )

        if keep_scores:
            self.attention_scores = get_attention_scores(attn, query, key, attention_mask)
            attention_probs = self.attention_scores.softmax(dim=-1).to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)

        output = attn.batch_to_head_dim(torch.bmm(attention_probs, value))
        # Apply the two stages of attn.to_out in order.
        output = attn.to_out[1](attn.to_out[0](output))

        if input_ndim == 4:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + residual

        return output / attn.rescale_output_factor
| |
|
| |
|
class PAIntAAttnProcessor:
    """Self-attention processor implementing PAIntA (prompt-aware introverted attention).

    At the resolutions listed in ``scale_factors`` the processor:

    1. runs the normal self-attention of ``attn`` and keeps its raw scores;
    2. replays the owning transformer block up to the input of its cross-attention
       (``transformer_block.attn2``) and measures, for the conditional branch, how
       strongly each pixel attends to the prompt tokens in ``token_idx``;
    3. smooths and normalizes that relevance map, and uses it to rescale the
       self-attention scores from masked query pixels to unmasked key pixels. Every
       other (query, key) pair keeps factor 1;
    4. recomputes self-attention with the rescaled scores.

    At any other resolution it falls back to a plain ``AttnProcessor``.

    The prompt embeddings are not passed to self-attention layers, so the caller must
    assign ``processor.encoder_hidden_states`` before each UNet forward.

    Args:
        transformer_block: Block that owns this self-attention layer; its ``norm1``,
            ``norm2``, ``pos_embed``, ``attn2`` and norm-type flags are read.
        mask: Inpainting mask. ``mask.shape[2:]`` is used as its spatial size and it is
            resized with ``F.interpolate`` (presumably (B, 1, H, W) - TODO confirm).
        token_idx: Indices of the prompt tokens whose cross-attention defines relevance.
        do_classifier_free_guidance: Whether batches are [unconditional, conditional].
        scale_factors: Downscale factors, relative to the mask, at which PAIntA runs.
    """

    def __init__(self, transformer_block, mask, token_idx, do_classifier_free_guidance, scale_factors):
        self.transformer_block = transformer_block
        self.mask = mask
        self.scale_factors = scale_factors
        self.do_classifier_free_guidance = do_classifier_free_guidance
        self.token_idx = token_idx
        self.shape = mask.shape[2:]
        # Pixel count of the mask (misspelled attribute name kept for compatibility).
        self.mask_resoltuion = mask.shape[-1] * mask.shape[-2]
        self.default_processor = AttnProcessor()
        # Prompt embeddings; injected by the pipeline before every UNet call.
        self.encoder_hidden_states = None

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        """Apply PAIntA at the configured resolutions, plain attention elsewhere."""
        # mask pixels / query tokens == (spatial downscale factor) ** 2
        downscale_factor = self.mask_resoltuion // hidden_states.shape[1]

        mask = None
        for factor in self.scale_factors:
            if downscale_factor == factor**2:
                shape = (self.shape[0] // factor, self.shape[1] // factor)
                mask = F.interpolate(self.mask, shape, mode="bicubic")
                break
        if mask is None:
            return self.default_processor(attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale)

        input_hidden_states = hidden_states

        # ---- 1. Regular self-attention, keeping the raw scores -------------------
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        self_attention_scores = get_attention_scores(attn, query, key, attention_mask)
        self_attention_probs = self_attention_scores.softmax(dim=-1).to(query.dtype)

        hidden_states = torch.bmm(self_attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        # Record the self-attention layout now: the cross-attention section below must
        # not clobber it before the final reshape.
        self_attention_layout = (batch_size, channel, height, width) if input_ndim == 4 else None

        if self_attention_layout is not None:
            hidden_states = hidden_states.transpose(-1, -2).reshape(*self_attention_layout)

        if attn.residual_connection:
            hidden_states = hidden_states + input_hidden_states

        self_attention_output_hidden_states = hidden_states / attn.rescale_output_factor

        # ---- 2. Replay the transformer block up to its cross-attention input -----
        block = self.transformer_block
        # NOTE(review): this is not the algebraic inverse of norm1's affine transform
        # ((y - bias) / weight); kept unchanged to preserve the method's outputs.
        unnormalized_input_hidden_states = (input_hidden_states + block.norm1.bias) * block.norm1.weight

        transformer_hidden_states = self_attention_output_hidden_states + unnormalized_input_hidden_states
        if transformer_hidden_states.ndim == 4:
            transformer_hidden_states = transformer_hidden_states.squeeze(1)

        if block.use_ada_layer_norm:
            raise NotImplementedError()
        elif block.use_ada_layer_norm_zero or block.use_layer_norm:
            transformer_norm_hidden_states = block.norm2(transformer_hidden_states)
        elif block.use_ada_layer_norm_single:
            transformer_norm_hidden_states = transformer_hidden_states
        elif block.use_ada_layer_norm_continuous:
            raise NotImplementedError()
        else:
            raise ValueError("Incorrect norm")

        if block.pos_embed is not None and block.use_ada_layer_norm_single is False:
            transformer_norm_hidden_states = block.pos_embed(transformer_norm_hidden_states)

        # ---- 3. Cross-attention probabilities of the conditional branch ----------
        if self.do_classifier_free_guidance:
            _, cross_attention_hidden_states = transformer_norm_hidden_states.chunk(2)
            _, cross_attention_encoder_hidden_states = self.encoder_hidden_states.chunk(2)
        else:
            # Without CFG the whole batch is conditional. Chunking here would produce a
            # tuple that attn2.to_k cannot consume.
            cross_attention_hidden_states = transformer_norm_hidden_states
            cross_attention_encoder_hidden_states = self.encoder_hidden_states

        attn2 = block.attn2

        if attn2.spatial_norm is not None:
            cross_attention_hidden_states = attn2.spatial_norm(cross_attention_hidden_states, temb)

        if cross_attention_hidden_states.ndim == 4:
            cross_batch, cross_channel, cross_height, cross_width = cross_attention_hidden_states.shape
            cross_attention_hidden_states = cross_attention_hidden_states.view(
                cross_batch, cross_channel, cross_height * cross_width
            ).transpose(1, 2)

        cross_batch, cross_sequence_length, _ = cross_attention_hidden_states.shape
        # No attention mask is passed to this cross-attention probe.
        cross_attention_mask = attn2.prepare_attention_mask(None, cross_sequence_length, cross_batch)

        if attn2.group_norm is not None:
            cross_attention_hidden_states = attn2.group_norm(
                cross_attention_hidden_states.transpose(1, 2)
            ).transpose(1, 2)

        query2 = attn2.to_q(cross_attention_hidden_states)

        if attn2.norm_cross:
            cross_attention_encoder_hidden_states = attn2.norm_encoder_hidden_states(
                cross_attention_encoder_hidden_states
            )

        key2 = attn2.to_k(cross_attention_encoder_hidden_states)
        query2 = attn2.head_to_batch_dim(query2)
        key2 = attn2.head_to_batch_dim(key2)

        cross_attention_probs = attn2.get_attention_scores(query2, key2, cross_attention_mask)

        # ---- 4. Build the rescaling factors ---------------------------------------
        mask = (mask > 0.5).to(self_attention_output_hidden_states.dtype)
        m = mask.to(self_attention_output_hidden_states.device)
        # Flatten to (B, h*w, C). m @ m^T is 1 where query and key are both masked;
        # adding (1 - m) sets every row of an unmasked query pixel to 1. The result
        # is 0 only for (masked query, unmasked key) pairs.
        m = m.permute(0, 2, 3, 1).reshape((m.shape[0], -1, m.shape[1])).contiguous()
        m = torch.matmul(m, m.permute(0, 2, 1)) + (1 - m)

        # The probabilities come from attn2, so split them by attn2's head count.
        batch_heads, dims, channels = cross_attention_probs.shape
        cross_batch_size = batch_heads // attn2.heads
        cross_attention_probs = cross_attention_probs.reshape((cross_batch_size, attn2.heads, dims, channels))

        # Average over heads, then sum the probabilities of the selected prompt tokens.
        cross_attention_probs = cross_attention_probs.mean(dim=1)
        cross_attention_probs = cross_attention_probs[..., self.token_idx].sum(dim=-1)
        cross_attention_probs = cross_attention_probs.reshape((cross_batch_size,) + shape)

        gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(
            self_attention_output_hidden_states.device
        )
        cross_attention_probs = gaussian_smoothing(cross_attention_probs[:, None])[:, 0]

        # Shift by the median, scale by the max, and clip to [0, 1].
        cross_attention_probs = cross_attention_probs.reshape(cross_batch_size, -1)
        cross_attention_probs = (
            cross_attention_probs - cross_attention_probs.median(dim=-1, keepdim=True).values
        ) / cross_attention_probs.max(dim=-1, keepdim=True).values
        cross_attention_probs = cross_attention_probs.clip(0, 1)

        # Masked-query -> unmasked-key pairs get the key pixel's relevance; others get 1.
        c = (1 - m) * cross_attention_probs.reshape(cross_batch_size, 1, -1) + m
        c = c.repeat_interleave(attn.heads, 0)
        if self.do_classifier_free_guidance:
            c = torch.cat([c, c])

        # ---- 5. Self-attention again with the rescaled scores ---------------------
        self_attention_scores_rescaled = self_attention_scores * c
        # Cast back like the first pass: scores may be float32 when attn upcasts.
        self_attention_probs_rescaled = self_attention_scores_rescaled.softmax(dim=-1).to(query.dtype)

        hidden_states = torch.bmm(self_attention_probs_rescaled, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if self_attention_layout is not None:
            hidden_states = hidden_states.transpose(-1, -2).reshape(*self_attention_layout)

        if attn.residual_connection:
            hidden_states = hidden_states + input_hidden_states

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
| |
|
| |
|
class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
    """Stable Diffusion inpainting pipeline extended with HD-Painter's PAIntA and RASG.

    PAIntA replaces self-attention processors, and only during the early steps (t >= 500).
    RASG records cross-attention scores at one resolution. It turns the gradient of a
    mask-alignment loss on those scores into the scheduler's ``variance_noise``.
    """

    def get_tokenized_prompt(self, prompt):
        """Tokenize ``prompt`` and return each token id decoded back to a string."""
        out = self.tokenizer(prompt)
        return [self.tokenizer.decode(x) for x in out["input_ids"]]

    def init_attn_processors(
        self,
        mask,
        token_idx,
        use_painta=True,
        use_rasg=True,
        painta_scale_factors=(2, 4),
        rasg_scale_factor=4,
        self_attention_layer_name="attn1",
        cross_attention_layer_name="attn2",
        list_of_painta_layer_names=None,
        list_of_rasg_layer_names=None,
    ):
        """Install PAIntA / RASG / default processors on every UNet attention layer.

        A layer is a PAIntA candidate if its processor name contains
        ``self_attention_layer_name``, or if it is listed in ``list_of_painta_layer_names``
        when that list is given. RASG candidates are selected the same way. Layers
        matching neither rule, or whose feature is disabled, get a plain ``AttnProcessor``.
        The returned mapping therefore always covers every layer, which
        ``set_attn_processor`` requires.

        Scale factors are given in latent units and converted to pixel units with
        ``self.vae_scale_factor``.
        """
        default_processor = AttnProcessor()
        painta_scale_factors = [x * self.vae_scale_factor for x in painta_scale_factors]
        rasg_scale_factor = self.vae_scale_factor * rasg_scale_factor

        attn_processors = {}
        for name in self.unet.attn_processors:
            if list_of_painta_layer_names is None:
                is_painta_layer = self_attention_layer_name in name
            else:
                is_painta_layer = name in list_of_painta_layer_names

            if list_of_rasg_layer_names is None:
                is_rasg_layer = cross_attention_layer_name in name
            else:
                is_rasg_layer = name in list_of_rasg_layer_names

            if is_painta_layer:
                if use_painta:
                    # "<transformer block path>.<attention name>.processor" -> block path
                    transformer_block = self.unet.get_submodule(name.rsplit(".", 2)[0])
                    attn_processors[name] = PAIntAAttnProcessor(
                        transformer_block, mask, token_idx, self.do_classifier_free_guidance, painta_scale_factors
                    )
                else:
                    attn_processors[name] = default_processor
            elif is_rasg_layer and use_rasg:
                attn_processors[name] = RASGAttnProcessor(mask, token_idx, rasg_scale_factor)
            else:
                # set_attn_processor rejects a dict that does not cover every layer.
                attn_processors[name] = default_processor

        self.unet.set_attn_processor(attn_processors)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: torch.Tensor = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        padding_mask_crop: Optional[int] = None,
        strength: float = 1.0,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        positive_prompt: Optional[str] = "",
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.01,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: int = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        use_painta=True,
        use_rasg=True,
        self_attention_layer_name=".attn1",
        cross_attention_layer_name=".attn2",
        painta_scale_factors=(2, 4),
        rasg_scale_factor=4,
        list_of_painta_layer_names=None,
        list_of_rasg_layer_names=None,
        **kwargs,
    ):
        """Inpaint ``image`` inside ``mask_image`` guided by ``prompt``.

        The arguments shared with ``StableDiffusionInpaintPipeline`` behave as in the
        base class. The HD-Painter-specific ones are:

        Args:
            positive_prompt: Text appended to every prompt before encoding. The guided
                token indices cover the original prompt's tokens plus the end-of-text
                token of the extended prompt.
            use_painta / use_rasg: Enable PAIntA self-attention and RASG guidance.
            self_attention_layer_name / cross_attention_layer_name: Substrings that select
                the processors for PAIntA and RASG, respectively.
            painta_scale_factors / rasg_scale_factor: Latent-space downscale factors at
                which PAIntA runs and at which RASG records attention.
            list_of_painta_layer_names / list_of_rasg_layer_names: Explicit processor
                names that override the substring selection.

        Raises:
            ValueError: If ``prompt`` is missing. Token positions are found by
                tokenizing the prompt text, so ``prompt_embeds`` alone is not enough.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if prompt is None:
            raise ValueError(
                "`prompt` must be provided: HD-Painter locates the prompt tokens by tokenizing the prompt text."
            )

        # Append the positive prompt; the raw prompt is kept for token indexing.
        prompt_no_positives = prompt
        if isinstance(prompt, list):
            prompt = [x + positive_prompt for x in prompt]
        else:
            prompt = prompt + positive_prompt

        # 1. Check inputs
        self.check_inputs(
            prompt,
            image,
            mask_image,
            height,
            width,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
            padding_mask_crop,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # Concatenate unconditional and text embeddings into one batch for CFG.
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
            image_embeds, negative_image_embeds = self.encode_image(
                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
            )
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 4. Set timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps=num_inference_steps, strength=strength, device=device
        )
        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        is_strength_max = strength == 1.0

        # 5. Preprocess mask and image
        if padding_mask_crop is not None:
            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        original_image = image
        init_image = self.image_processor.preprocess(
            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
        )
        init_image = init_image.to(dtype=torch.float32)

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        num_channels_unet = self.unet.config.in_channels
        return_image_latents = num_channels_unet == 4

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        # 7. Prepare mask latent variables
        mask_condition = self.mask_processor.preprocess(
            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
        )

        if masked_image_latents is None:
            masked_image = init_image * (mask_condition < 0.5)
        else:
            masked_image = masked_image_latents

        mask, masked_image_latents = self.prepare_mask_latents(
            mask_condition,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            self.do_classifier_free_guidance,
        )

        # 7.1 Token indices: the raw prompt's tokens (skipping the start token) plus the
        # end-of-text token of the prompt that includes the positive suffix.
        token_idx = list(range(1, self.get_tokenized_prompt(prompt_no_positives).index("<|endoftext|>"))) + [
            self.get_tokenized_prompt(prompt).index("<|endoftext|>")
        ]

        # 7.2 Install the HD-Painter attention processors
        self.init_attn_processors(
            mask_condition,
            token_idx,
            use_painta,
            use_rasg,
            painta_scale_factors=painta_scale_factors,
            rasg_scale_factor=rasg_scale_factor,
            self_attention_layer_name=self_attention_layer_name,
            cross_attention_layer_name=cross_attention_layer_name,
            list_of_painta_layer_names=list_of_painta_layer_names,
            list_of_rasg_layer_names=list_of_rasg_layer_names,
        )

        # 8. Check that sizes of mask, masked image and latents match
        if num_channels_unet == 9:
            num_channels_mask = mask.shape[1]
            num_channels_masked_image = masked_image_latents.shape[1]
            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
                raise ValueError(
                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                    " `pipeline.unet` or your `mask_image` or `image` input."
                )
        elif num_channels_unet != 4:
            raise ValueError(
                f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
            )

        # 9. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # With RASG the stochastic term is supplied as `variance_noise`, not sampled.
        if use_rasg:
            extra_step_kwargs["generator"] = None

        # 9.1 Added conditions for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 9.2 Optionally get guidance scale embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 9.3 RASG target mask at the resolution where attention scores are recorded.
        # It is loop-invariant, so compute it once.
        # NOTE(review): only the first mask of the batch is used ([0, 0]).
        if use_rasg:
            _, _, mask_height, mask_width = mask_condition.shape
            rasg_downscale = self.vae_scale_factor * rasg_scale_factor
            rasg_height, rasg_width = mask_height // rasg_downscale, mask_width // rasg_downscale
            rasg_mask = F.interpolate(mask_condition, (rasg_height, rasg_width), mode="bicubic")[0, 0].to(device)

        # Gradients through the UNet are only needed for RASG's guidance signal.
        needs_grad = bool(use_rasg)

        # 10. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        painta_active = True

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # PAIntA is only applied in the early steps; switch it off once t < 500.
                if t < 500 and painta_active:
                    self.init_attn_processors(
                        mask_condition,
                        token_idx,
                        False,
                        use_rasg,
                        painta_scale_factors=painta_scale_factors,
                        rasg_scale_factor=rasg_scale_factor,
                        self_attention_layer_name=self_attention_layer_name,
                        cross_attention_layer_name=cross_attention_layer_name,
                        list_of_painta_layer_names=list_of_painta_layer_names,
                        list_of_rasg_layer_names=list_of_rasg_layer_names,
                    )
                    painta_active = False

                with torch.set_grad_enabled(needs_grad):
                    self.unet.zero_grad()
                    latents = latents.detach()
                    latents.requires_grad = needs_grad

                    # Expand the latents if we are doing classifier-free guidance
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # Concat latents, mask, masked_image_latents in the channel dimension
                    if num_channels_unet == 9:
                        latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

                    self.scheduler.latents = latents
                    # Self-attention processors do not receive the prompt embeddings, so inject them.
                    self.encoder_hidden_states = prompt_embeds
                    for attn_processor in self.unet.attn_processors.values():
                        attn_processor.encoder_hidden_states = prompt_embeds

                    # Predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                    # Perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    if use_rasg:
                        # Collect the raw scores recorded by the RASG processors (conditional half only).
                        attn_map = []
                        for processor in self.unet.attn_processors.values():
                            if hasattr(processor, "attention_scores") and processor.attention_scores is not None:
                                if self.do_classifier_free_guidance:
                                    attn_map.append(processor.attention_scores.chunk(2)[1])
                                else:
                                    attn_map.append(processor.attention_scores)

                        attn_map = (
                            torch.cat(attn_map).mean(0).permute(1, 0).reshape((-1, rasg_height, rasg_width))
                        )

                        # Negative BCE between each selected token's score map and the mask.
                        attn_score = -sum(
                            [F.binary_cross_entropy_with_logits(x - 1.0, rasg_mask) for x in attn_map[token_idx]]
                        )

                        # Only d(score)/d(latents) is needed; autograd.grad avoids also
                        # accumulating gradients into every UNet parameter.
                        (variance_noise,) = torch.autograd.grad(attn_score, latents)

                        # Normalize the gradient per sample to zero mean and unit std.
                        variance_noise -= torch.mean(variance_noise, [1, 2, 3], keepdim=True)
                        variance_noise /= torch.std(variance_noise, [1, 2, 3], keepdim=True)
                    else:
                        variance_noise = None

                # Compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False, variance_noise=variance_noise
                )[0]

                if num_channels_unet == 4:
                    init_latents_proper = image_latents
                    if self.do_classifier_free_guidance:
                        init_mask, _ = mask.chunk(2)
                    else:
                        init_mask = mask

                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = self.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor([noise_timestep])
                        )

                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    mask = callback_outputs.pop("mask", mask)
                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)

                # Call the legacy callback, if provided (every step when callback_steps is unset)
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % (callback_steps or 1) == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            condition_kwargs = {}
            if isinstance(self.vae, AsymmetricAutoencoderKL):
                init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
                init_image_condition = init_image.clone()
                init_image = self._encode_vae_image(init_image, generator=generator)
                mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
                condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if padding_mask_crop is not None:
            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
| |
|
| |
|
| | |
| |
|
| |
|
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed separately
    for each channel in the input using a depthwise convolution with "same" padding.

    Arguments:
        channels (int): Number of channels of the input tensors. Output will have this
            number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel (per axis if a sequence).
        sigma (float, sequence): Standard deviation of the gaussian kernel (per axis if a
            sequence).
        dim (int, optional): The number of dimensions of the data. Default value is 2
            (spatial).

    Raises:
        RuntimeError: If ``dim`` is not 1, 2 or 3.
    """

    def __init__(self, channels, kernel_size, sigma, dim=2):
        super().__init__()
        if dim not in (1, 2, 3):
            raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))

        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The N-d kernel is the product of one 1-d Gaussian per axis, each evaluated on
        # its own coordinate grid.
        kernel = 1
        meshgrids = torch.meshgrid(
            [torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij"
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # exp(-(x - mean)^2 / (2 * std^2)). The old form exp(-((x - mean) / (2 * std))^2)
            # widened the kernel by sqrt(2). The normalizing constant cancels out below.
            kernel = kernel * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)

        # Make sure the weights sum to 1.
        kernel = kernel / torch.sum(kernel)

        # Depthwise weight of shape (channels, 1, *kernel_size).
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer("weight", kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        else:
            self.conv = F.conv3d

    def forward(self, input):
        """
        Apply gaussian filter to input.

        Arguments:
            input (torch.Tensor): Input of shape (batch, channels, *spatial).
        Returns:
            filtered (torch.Tensor): Filtered output with the same shape as ``input``.
        """
        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same")
| |
|
| |
|
def get_attention_scores(
    self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
) -> torch.Tensor:
    r"""
    Compute raw (pre-softmax) attention scores for an attention module.

    No softmax is applied, so callers can store or reweight the scores before
    normalizing them.

    Args:
        self: The attention module. Only ``upcast_attention``, ``upcast_softmax`` and
            ``scale`` are read.
        query (`torch.Tensor`): Batched query tensor (batch, query_len, dim).
        key (`torch.Tensor`): Batched key tensor (batch, key_len, dim).
        attention_mask (`torch.Tensor`, *optional*): Additive bias added to the scores.
            If `None`, no mask is applied.

    Returns:
        `torch.Tensor`: ``self.scale * query @ key^T`` (plus the mask when given),
        cast to float32 when ``self.upcast_softmax`` is set.
    """
    if self.upcast_attention:
        query, key = query.float(), key.float()

    key_t = key.transpose(-1, -2)

    if attention_mask is not None:
        scores = torch.baddbmm(attention_mask, query, key_t, beta=1, alpha=self.scale)
    else:
        # With beta=0 baddbmm ignores the (uninitialized) input buffer entirely.
        placeholder = torch.empty(
            query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
        )
        scores = torch.baddbmm(placeholder, query, key_t, beta=0, alpha=self.scale)
        del placeholder

    return scores.float() if self.upcast_softmax else scores
| |
|