| | import itertools |
| | from functools import partial |
| | from typing import Any, Dict, Tuple, Callable |
| | from typing import Union, Optional, List |
| |
|
| | import numpy as np |
| | import torch |
| | from diffusers import DPMSolverMultistepScheduler |
| | from diffusers import StableDiffusionPipeline, AutoencoderKL |
| | from diffusers import Transformer2DModel, ModelMixin, ConfigMixin, SchedulerMixin |
| | from diffusers import UNet2DConditionModel |
| | from diffusers.configuration_utils import register_to_config |
| | from diffusers.models.attention import BasicTransformerBlock |
| | from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D |
| | from diffusers.models.transformer_2d import Transformer2DModelOutput |
| | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput |
| | from diffusers.schedulers import KarrasDiffusionSchedulers |
| | from diffusers.utils import replace_example_docstring |
| | from torch import nn |
| | from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor |
| |
|
| |
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend the classifier-free-guidance prediction with a std-matched copy of itself.

    Based on findings of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

    :param noise_cfg: guided noise prediction; statistics are taken over every dim except dim 0
    :param noise_pred_text: text-conditioned noise prediction whose per-sample std is the target
    :param guidance_rescale: blend weight; 0 returns `noise_cfg` unchanged, 1 returns the fully rescaled tensor
    :return: ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``
    """
    # Reduce over all non-batch dimensions, keeping dims so the ratio broadcasts back.
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    std_ratio = noise_pred_text.std(dim=text_dims, keepdim=True) / noise_cfg.std(dim=cfg_dims, keepdim=True)

    rescaled = noise_cfg * std_ratio
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
| |
|
| |
|
def custom_sort_order(obj):
    """
    Key function for sorting order of execution in forward methods.

    Resnet blocks sort before transformer (attention) blocks:
    returns 0 for `ResnetBlock2D` instances and 1 for `Transformer2DModel` /
    `FlexibleTransformer2DModel` instances. Python's sort is stable, so modules
    of the same kind keep their relative order.

    :param obj: a module from a flexible block's module list
    :return: 0 or 1
    :raises TypeError: for any other module type. Returning ``None`` instead would make
        ``sorted`` fail later with an unhelpful ``'<' not supported`` comparison error.
    """
    # isinstance (rather than an exact-class lookup) also accepts subclasses of the known blocks.
    if isinstance(obj, ResnetBlock2D):
        return 0
    if isinstance(obj, (Transformer2DModel, FlexibleTransformer2DModel)):
        return 1
    raise TypeError(f"Cannot determine execution order for module of type {type(obj).__name__}")
| |
|
| |
|
def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    """
    Squeeze a timestep schedule to length `n`, keeping its first `i` entries.

    The first `i` entries are copied from `timestep_spacing`. The remaining `n - i`
    entries become ``(n - i) * k, ..., 2 * k, k``, where
    ``k = timestep_spacing[i - 1] // (n - i + 1)`` (integer division). In other words,
    the last kept entry is treated as ``(n - i + 1) * k``.

    :param n: the size of the squeezed array
    :param i: the index we start squeezing from; must satisfy ``1 <= i < n``
    :param timestep_spacing: the timestep_spacing array we want to squeeze; must hold at least `i` entries
    :return: squeezed timestep_spacing, a new length-`n` integer numpy array (the input is not modified)
    :raises ValueError: if `i` is outside ``[1, n)`` or `timestep_spacing` has fewer than `i` entries
    Example:
        timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60]) (len=16)
        n = 10, i = 6
    Expected:
        [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k], and if we define 665=5k => k = 133
    """
    # Explicit checks instead of `assert`: asserts are stripped under `python -O`.
    # i == 0 would read squeezed[-1] below and silently yield an all-zero schedule.
    if not 1 <= i < n:
        raise ValueError(f"Expected 1 <= i < n, got i={i}, n={n}")
    if len(timestep_spacing) < i:
        raise ValueError(f"timestep_spacing has {len(timestep_spacing)} entries, need at least i={i}")

    # Start from the descending multipliers [n, n-1, ..., 1].
    squeezed = np.flip(np.arange(n)) + 1
    squeezed[:i] = timestep_spacing[:i]
    k = squeezed[i - 1] // (n - i + 1)
    squeezed[i:] *= k
    return squeezed
| |
|
| |
|
# Named timestep squeezers, keyed "n,i". Each one maps a full timestep schedule to a length-n schedule
# that keeps its first i entries (see squeeze_to_len_n_starting_from_index_i).
# The scheduler looks them up by its `squeeze_mode`.
PREDEFINED_TIMESTEP_SQUEEZERS = {
    f"{length},{start}": partial(squeeze_to_len_n_starting_from_index_i, length, start)
    for length, start in ((10, 6), (11, 7))
}
| |
|
# Architecture description consumed by FlexibleUNet2DConditionModel.
FlexibleUnetConfigurations = dict(
    # Settings shared by every block.
    sample_size=64,
    temb_dim=320 * 4,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=8,
    cross_attention_dim=768,
    # True: resnets and attentions are interleaved; False: all resnets run before all attentions.
    mix_block_in_forward=True,
    # Down path: one entry per down block.
    down_blocks_in_channels=[320, 320, 640],
    down_blocks_out_channels=[320, 640, 1280],
    down_blocks_num_attentions=[0, 1, 3],
    down_blocks_num_resnets=[2, 2, 1],
    add_downsample=[True, True, False],
    # Mid block: 0 resnets and 0 attentions makes the UNet use an identity mid block.
    add_upsample_mid_block=None,
    mid_num_resnets=0,
    mid_num_attentions=0,
    # Up path: one entry per up block.
    prev_output_channels=[1280, 1280, 640],
    up_blocks_num_attentions=[5, 3, 0],
    up_blocks_num_resnets=[2, 3, 3],
    add_upsample=[True, True, False],
)
| |
|
| |
|
class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler, SchedulerMixin):
    """
    This is a copy-paste from Diffuser's `DPMSolverMultistepScheduler`, with minor differences:
    * Defaults are modified to accommodate DeciDiffusion
    * It supports a squeezer to squeeze the number of inference steps to a smaller number
    //!\\ IMPORTANT: the actual number of inference steps is deduced by the squeezer, and not the pipeline!
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "squaredcos_cap_v2",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "v_prediction",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "heun",
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -7.5,
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 1,
        squeeze_mode: Optional[str] = None,
    ):
        """
        Build the scheduler; all arguments except `squeeze_mode` are forwarded to `DPMSolverMultistepScheduler`.

        :param squeeze_mode: key of PREDEFINED_TIMESTEP_SQUEEZERS (e.g. "10,6"), or None to disable squeezing
        :raises ValueError: if `squeeze_mode` is not None and not a predefined squeezer key
        :raises NotImplementedError: if `use_karras_sigmas` is truthy
        """
        # A typo in `squeeze_mode` used to silently disable squeezing; fail loudly instead.
        if squeeze_mode is not None and squeeze_mode not in PREDEFINED_TIMESTEP_SQUEEZERS:
            raise ValueError(
                f"Unknown squeeze_mode {squeeze_mode!r}; expected one of "
                f"{sorted(PREDEFINED_TIMESTEP_SQUEEZERS)} or None"
            )
        self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode)

        if use_karras_sigmas:
            raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`")

        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            solver_order=solver_order,
            prediction_type=prediction_type,
            thresholding=thresholding,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
            sample_max_value=sample_max_value,
            algorithm_type=algorithm_type,
            solver_type=solver_type,
            lower_order_final=lower_order_final,
            use_karras_sigmas=False,
            lambda_min_clipped=lambda_min_clipped,
            variance_type=variance_type,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )

    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        When a squeezer is configured, the parent's schedule is squeezed afterwards. In that case
        `self.num_inference_steps` becomes the squeezed length, not `num_inference_steps`. The parent
        schedule must contain at least as many steps as the squeezer keeps.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        super().set_timesteps(num_inference_steps=num_inference_steps, device=device)
        if self._squeezer is None:
            return

        timesteps = self._squeezer(self.timesteps.cpu())
        # Recompute sigmas for the squeezed timesteps by interpolating the per-train-step sigma table,
        # then append the sigma at training timestep 0 as the final entry.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
        self.num_inference_steps = len(timesteps)
| |
|
| |
|
class FlexibleIdentityBlock(nn.Module):
    """
    No-op block that hands `hidden_states` back untouched.

    Its forward signature mirrors the other flexible blocks, so it can stand in as a block (it is
    used as the mid block when that block has no resnets and no attentions). Every other argument
    is accepted and ignored.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        # The conditioning inputs are intentionally unused.
        del temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask
        output = hidden_states
        return output
| |
|
| |
|
class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    """
    `UNet2DConditionModel` whose down, mid and up blocks are rebuilt from `configurations`.

    The parent constructor first builds a default UNet, forwarding only `sample_size` and
    `cross_attention_dim`. Its `down_blocks`, `mid_block` and `up_blocks` are then replaced, in
    that order, by the flexible blocks described in `configurations`.
    """

    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        cfg = self.configurations
        super().__init__(
            sample_size=cfg.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
            cross_attention_dim=cfg.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
        )

        # Keyword arguments every flexible block (down, mid, up) receives unchanged.
        shared_kwargs = dict(
            temb_channels=cfg.get("temb_dim"),
            resnet_eps=cfg.get("resnet_eps"),
            resnet_act_fn=cfg.get("resnet_act_fn"),
            num_attention_heads=cfg.get("num_attention_heads"),
            cross_attention_dim=cfg.get("cross_attention_dim"),
            mix_block_in_forward=cfg.get("mix_block_in_forward"),
        )

        # ---- Down path ----
        down_in = cfg.get("down_blocks_in_channels")
        down_out = cfg.get("down_blocks_out_channels")
        final_down_index = len(down_in) - 1
        down_specs = zip(
            down_in,
            down_out,
            cfg.get("down_blocks_num_resnets"),
            cfg.get("down_blocks_num_attentions"),
            cfg.get("add_downsample"),
        )
        self.down_blocks = nn.ModuleList()
        for index, (c_in, c_out, resnet_count, attention_count, downsample) in enumerate(down_specs):
            self.down_blocks.append(
                FlexibleCrossAttnDownBlock2D(
                    in_channels=c_in,
                    out_channels=c_out,
                    num_resnets=resnet_count,
                    num_attentions=attention_count,
                    add_downsample=downsample,
                    last_block=index == final_down_index,
                    **shared_kwargs,
                )
            )

        # ---- Mid block ----
        mid_resnet_count = cfg.get("mid_num_resnets")
        mid_attention_count = cfg.get("mid_num_attentions")
        if mid_resnet_count == mid_attention_count == 0:
            self.mid_block = FlexibleIdentityBlock()
        else:
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
                in_channels=down_out[-1],
                num_resnets=mid_resnet_count,
                num_attentions=mid_attention_count,
                add_upsample=cfg.get("add_upsample_mid_block"),
                **shared_kwargs,
            )

        # ---- Up path: mirrors the down path's channel lists in reverse ----
        up_specs = zip(
            reversed(down_in),
            reversed(down_out),
            cfg.get("prev_output_channels"),
            cfg.get("up_blocks_num_resnets"),
            cfg.get("up_blocks_num_attentions"),
            cfg.get("add_upsample"),
        )
        self.up_blocks = nn.ModuleList()
        for c_in, c_out, c_prev, resnet_count, attention_count, upsample in up_specs:
            self.up_blocks.append(
                FlexibleCrossAttnUpBlock2D(
                    in_channels=c_in,
                    out_channels=c_out,
                    prev_output_channel=c_prev,
                    num_resnets=resnet_count,
                    num_attentions=attention_count,
                    add_upsample=upsample,
                    **shared_kwargs,
                )
            )
| |
|
| |
|
class FlexibleCrossAttnDownBlock2D(nn.Module):
    """
    UNet down block built from a configurable number of resnets and cross-attention transformers.

    Resnets and attention modules are paired position by position (resnet, attention, resnet,
    attention, ...). Whichever kind has more entries fills the tail. With
    `mix_block_in_forward=False` the list is stably re-sorted so every resnet runs before every
    attention module. An optional `Downsample2D` runs after all of them.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()

        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        layers = []
        pairing = itertools.zip_longest([True] * num_resnets, [True] * num_attentions, fillvalue=False)
        for position, (want_resnet, want_attention) in enumerate(pairing):
            # Only a resnet at position 0 reads the block's input width; all later ones read out_channels.
            resnet_in = in_channels if position == 0 else out_channels
            if want_resnet:
                layers.append(
                    ResnetBlock2D(
                        in_channels=resnet_in,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if want_attention:
                layers.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        ordered = layers if mix_block_in_forward else sorted(layers, key=custom_sort_order)
        self.modules_list = nn.ModuleList(ordered)

        self.downsamplers = (
            nn.ModuleList(
                [Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")]
            )
            if add_downsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run the block.

        :return: ``(hidden_states, skip_states)``. `skip_states` holds the output of every resnet.
            When this is not the last block, it also holds the final (post-downsample) hidden states.
        :raises ValueError: if `modules_list` contains a module that is neither a resnet nor a transformer
        """
        skip_states = []

        for layer in self.modules_list:
            if isinstance(layer, ResnetBlock2D):
                hidden_states = layer(hidden_states, temb)
                # Skip connections are captured after resnets only, not after attention modules.
                skip_states.append(hidden_states)
            elif isinstance(layer, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f"Got an unexpected module in modules list! {type(layer)}")

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        if not self.last_block:
            skip_states.append(hidden_states)

        return hidden_states, tuple(skip_states)
| |
|
| |
|
class FlexibleCrossAttnUpBlock2D(nn.Module):
    """
    UNet up block built from a configurable number of resnets and cross-attention transformers.

    Each resnet consumes one skip connection, popped from the end of `res_hidden_states_tuple`
    and concatenated on the channel dim. Module pairing and ordering follow the same rules as the
    down block. An optional `Upsample2D` runs at the end.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        layers = []

        # Plain list with one placeholder per resnet. Presumably the parent UNet's forward uses
        # len(block.resnets) to decide how many skip connections to hand this block.
        # NOTE(review): confirm against the diffusers version in use.
        self.resnets = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        resnet_flags = [True] * num_resnets
        final_resnet_position = len(resnet_flags) - 1
        pairing = itertools.zip_longest(resnet_flags, [True] * num_attentions, fillvalue=False)
        for position, (want_resnet, want_attention) in enumerate(pairing):
            # The last resnet receives the block-input-width skip; the first one receives the previous block's output.
            skip_channels = in_channels if position == final_resnet_position else out_channels
            resnet_in = prev_output_channel if position == 0 else out_channels

            if want_resnet:
                self.resnets.append(True)
                layers.append(
                    ResnetBlock2D(
                        in_channels=resnet_in + skip_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if want_attention:
                layers.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        ordered = layers if mix_block_in_forward else sorted(layers, key=custom_sort_order)
        self.modules_list = nn.ModuleList(ordered)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run the block, consuming skip connections from the end of `res_hidden_states_tuple`."""
        pending_skips = list(res_hidden_states_tuple)

        for layer in self.modules_list:
            if isinstance(layer, ResnetBlock2D):
                skip = pending_skips.pop()
                hidden_states = layer(torch.cat([hidden_states, skip], dim=1), temb)
            elif isinstance(layer, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
| |
|
| |
|
class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    """
    Configurable UNet mid block: a leading resnet followed by cross-attention / resnet modules.

    After the leading `ResnetBlock2D`, attention and resnet modules are appended position by
    position (attention first at each position). With `mix_block_in_forward=False` the list is
    stably re-sorted so resnets come first. The leading resnet stays at index 0 either way.
    An upsampler (or identity) runs at the end.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        def make_resnet():
            # Every mid-block resnet preserves the channel count.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        modules = [make_resnet()]

        pairing = itertools.zip_longest([True] * num_resnets, [True] * num_attentions, fillvalue=False)
        for add_resnet, add_cross_attention in pairing:
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )
            if add_resnet:
                modules.append(make_resnet())

        if not mix_block_in_forward:
            # Stable sort: the leading resnet remains at index 0.
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Run the leading resnet once, then every remaining module in order, then the upsampler."""
        # Bug fix: the leading resnet used to be applied here AND again inside the loop, because the
        # loop also visited modules_list[0]. Iterate only the remaining modules.
        leading_resnet, *remaining = self.modules_list
        hidden_states = leading_resnet(hidden_states, temb)

        for module in remaining:
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states
| |
|
| |
|
class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    """
    Continuous-input 2D transformer used inside the flexible UNet blocks.

    Pipeline: GroupNorm -> `proj_in` (1x1 conv, or linear when `use_linear_projection`) ->
    `num_layers` `BasicTransformerBlock`s over the flattened spatial tokens -> `proj_out` ->
    residual add with the input.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # Input projection.
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Transformer blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Output projection. It maps back to in_channels, which the residual add requires.
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        """
        Apply the transformer to a (batch, channels, height, width) feature map.

        :param return_dict: if True return a `Transformer2DModelOutput`, otherwise a 1-tuple ``(output,)``.
            Both forms support ``[0]`` indexing to get the output tensor.
        """
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        # Normalize, project and flatten to (batch, height * width, channels) tokens.
        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Here inner_dim is the pre-projection channel count (== in_channels), which is also what
            # proj_out emits, so the reshape back below is consistent.
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # Un-flatten and project back to in_channels.
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        # Bug fix: the flag was inverted (True returned a tuple, False returned the output object).
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
| |
|
| |
|
| | class DeciDiffusionPipeline(StableDiffusionPipeline): |
| | deci_default_squeeze_mode = "10,6" |
| | deci_default_number_of_iterations = 16 |
| | deci_default_guidance_rescale = 0.8 |
| |
|
| | def __init__( |
| | self, |
| | vae: AutoencoderKL, |
| | text_encoder: CLIPTextModel, |
| | tokenizer: CLIPTokenizer, |
| | unet: UNet2DConditionModel, |
| | scheduler: KarrasDiffusionSchedulers, |
| | safety_checker: StableDiffusionSafetyChecker, |
| | feature_extractor: CLIPImageProcessor, |
| | requires_safety_checker: bool = True, |
| | ): |
| | |
| | del unet |
| | unet = FlexibleUNet2DConditionModel() |
| |
|
| | |
| | del scheduler |
| | scheduler = SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode) |
| |
|
| | super().__init__( |
| | vae=vae, |
| | text_encoder=text_encoder, |
| | tokenizer=tokenizer, |
| | unet=unet, |
| | scheduler=scheduler, |
| | safety_checker=safety_checker, |
| | feature_extractor=feature_extractor, |
| | requires_safety_checker=requires_safety_checker, |
| | ) |
| |
|
| | self.register_modules( |
| | vae=vae, |
| | text_encoder=text_encoder, |
| | tokenizer=tokenizer, |
| | unet=unet, |
| | scheduler=scheduler, |
| | safety_checker=safety_checker, |
| | feature_extractor=feature_extractor, |
| | ) |
| |
|
| | @torch.no_grad() |
| | def __call__( |
| | self, |
| | prompt: Union[str, List[str]] = None, |
| | height: Optional[int] = None, |
| | width: Optional[int] = None, |
| | num_inference_steps: int = 16, |
| | guidance_scale: float = 7.5, |
| | negative_prompt: Optional[Union[str, List[str]]] = None, |
| | num_images_per_prompt: Optional[int] = 1, |
| | eta: float = 0.0, |
| | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| | latents: Optional[torch.FloatTensor] = None, |
| | prompt_embeds: Optional[torch.FloatTensor] = None, |
| | negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
| | output_type: Optional[str] = "pil", |
| | return_dict: bool = True, |
| | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
| | callback_steps: int = 1, |
| | cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
| | guidance_rescale: float = 0.8, |
| | ): |
| | r""" |
| | The call function to the pipeline for generation. |
| | |
| | Args: |
| | prompt (`str` or `List[str]`, *optional*): |
| | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. |
| | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): |
| | The height in pixels of the generated image. |
| | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): |
| | The width in pixels of the generated image. |
| | num_inference_steps (`int`, *optional*, defaults to 50): |
| | The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| | expense of slower inference. |
| | guidance_scale (`float`, *optional*, defaults to 7.5): |
| | A higher guidance scale value encourages the model to generate images closely linked to the text |
| | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. |
| | negative_prompt (`str` or `List[str]`, *optional*): |
| | The prompt or prompts to guide what to not include in image generation. If not defined, you need to |
| | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). |
| | num_images_per_prompt (`int`, *optional*, defaults to 1): |
| | The number of images to generate per prompt. |
| | eta (`float`, *optional*, defaults to 0.0): |
| | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies |
| | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. |
| | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
| | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make |
| | generation deterministic. |
| | latents (`torch.FloatTensor`, *optional*): |
| | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image |
| | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
| | tensor is generated by sampling using the supplied random `generator`. |
| | prompt_embeds (`torch.FloatTensor`, *optional*): |
| | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not |
| | provided, text embeddings are generated from the `prompt` input argument. |
| | negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
| | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If |
| | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. |
| | output_type (`str`, *optional*, defaults to `"pil"`): |
| | The output format of the generated image. Choose between `PIL.Image` or `np.array`. |
| | return_dict (`bool`, *optional*, defaults to `True`): |
| | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
| | plain tuple. |
| | callback (`Callable`, *optional*): |
| |                 A function invoked every `callback_steps` steps during inference. The function is called with the |
| | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. |
| | callback_steps (`int`, *optional*, defaults to 1): |
| | The frequency at which the `callback` function is called. If not specified, the callback is called at |
| | every step. |
| | cross_attention_kwargs (`dict`, *optional*): |
| | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in |
| | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
| |         guidance_rescale (`float`, *optional*, defaults to 0.8): |
| | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are |
| | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when |
| | using zero terminal SNR. |
| | |
| | Examples: |
| | |
| | Returns: |
| | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
| | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, |
| | otherwise a `tuple` is returned where the first element is a list with the generated images and the |
| | second element is a list of `bool`s indicating whether the corresponding generated image contains |
| | "not-safe-for-work" (nsfw) content. |
| | """ |
| | |
| | height = height or self.unet.config.sample_size * self.vae_scale_factor |
| | width = width or self.unet.config.sample_size * self.vae_scale_factor |
| |
|
| | |
| | self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) |
| |
|
| | |
| | if prompt is not None and isinstance(prompt, str): |
| | batch_size = 1 |
| | elif prompt is not None and isinstance(prompt, list): |
| | batch_size = len(prompt) |
| | else: |
| | batch_size = prompt_embeds.shape[0] |
| |
|
| | device = self._execution_device |
| | |
| | |
| | |
| | do_classifier_free_guidance = guidance_scale > 1.0 |
| |
|
| | |
| | text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None |
| | prompt_embeds, negative_prompt_embeds = self.encode_prompt( |
| | prompt, |
| | device, |
| | num_images_per_prompt, |
| | do_classifier_free_guidance, |
| | negative_prompt, |
| | prompt_embeds=prompt_embeds, |
| | negative_prompt_embeds=negative_prompt_embeds, |
| | lora_scale=text_encoder_lora_scale, |
| | ) |
| | |
| | |
| | |
| | if do_classifier_free_guidance: |
| | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) |
| |
|
| | |
| | self.scheduler.set_timesteps(num_inference_steps, device=device) |
| | timesteps = self.scheduler.timesteps |
| |
|
| | |
| | num_channels_latents = self.unet.config.in_channels |
| | latents = self.prepare_latents( |
| | batch_size * num_images_per_prompt, |
| | num_channels_latents, |
| | height, |
| | width, |
| | prompt_embeds.dtype, |
| | device, |
| | generator, |
| | latents, |
| | ) |
| |
|
| | |
| | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
| |
|
| | |
| | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
| | with self.progress_bar(total=len(timesteps)) as progress_bar: |
| | for i, t in enumerate(timesteps): |
| | |
| | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
| |
|
| | |
| | noise_pred = self.unet( |
| | latent_model_input, |
| | t, |
| | encoder_hidden_states=prompt_embeds, |
| | cross_attention_kwargs=cross_attention_kwargs, |
| | return_dict=False, |
| | )[0] |
| |
|
| | |
| | if do_classifier_free_guidance: |
| | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
| | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| |
|
| | if do_classifier_free_guidance and guidance_rescale > 0.0: |
| | |
| | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) |
| |
|
| | |
| | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] |
| |
|
| | |
| | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
| | progress_bar.update() |
| | if callback is not None and i % callback_steps == 0: |
| | callback(i, t, latents) |
| |
|
| | if not output_type == "latent": |
| | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
| | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
| | else: |
| | image = latents |
| | has_nsfw_concept = None |
| |
|
| | if has_nsfw_concept is None: |
| | do_denormalize = [True] * image.shape[0] |
| | else: |
| | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] |
| |
|
| | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) |
| |
|
| | |
| | self.maybe_free_model_hooks() |
| |
|
| | if not return_dict: |
| | return (image, has_nsfw_concept) |
| |
|
| | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
| |
|