import itertools
from typing import Any, Dict, Optional, Tuple

import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D

try:
    from diffusers.models.transformer_2d import Transformer2DModelOutput
except ImportError:  # newer diffusers releases moved this module
    from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput

from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

FlexibleUnetConfigurations = {
    # General parameters
    'sample_size': 64,
    'temb_dim': 320 * 4,
    'resnet_eps': 1e-5,
    'resnet_act_fn': 'silu',
    'num_attention_heads': 8,
    'cross_attention_dim': 768,

    # If True, resnets and attentions may interleave inside a block; if False,
    # each block is sorted so all resnets run before all attentions.
    'mix_block_in_forward': True,

    # Down blocks parameters (one entry per level)
    'down_blocks_in_channels': [320, 320, 640],
    'down_blocks_out_channels': [320, 640, 1280],
    'down_blocks_num_attentions': [0, 1, 3],
    'down_blocks_num_resnets': [2, 2, 1],
    'add_downsample': [True, True, True],

    # Mid block parameters
    'add_upsample_mid_block': True,
    'mid_num_resnets': 4,
    'mid_num_attentions': 2,

    # Up blocks parameters (one entry per level)
    'prev_output_channels': [1280, 1280, 640],
    'up_blocks_num_attentions': [6, 3, 0],
    'up_blocks_num_resnets': [2, 3, 3],
    'add_upsample': [True, True, False],
}

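# Lightweight sanity check (added sketch, not part of the original module): the
# per-level lists above must stay aligned, one entry per down/up level, since
# they are zipped together when the blocks are built below.
assert (
    len(FlexibleUnetConfigurations['down_blocks_in_channels'])
    == len(FlexibleUnetConfigurations['down_blocks_out_channels'])
    == len(FlexibleUnetConfigurations['down_blocks_num_resnets'])
    == len(FlexibleUnetConfigurations['down_blocks_num_attentions'])
    == len(FlexibleUnetConfigurations['add_downsample'])
    == len(FlexibleUnetConfigurations['prev_output_channels'])
    == len(FlexibleUnetConfigurations['up_blocks_num_resnets'])
    == len(FlexibleUnetConfigurations['up_blocks_num_attentions'])
    == len(FlexibleUnetConfigurations['add_upsample'])
), 'FlexibleUnetConfigurations: per-level lists must have the same length'
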
def custom_sort_order(obj):
    """
    Key function for sorting the order of execution in forward methods:
    resnets first, transformer blocks second, anything else last.
    """
    # The explicit default keeps sorted() from raising on an unexpected module type.
    return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(type(obj), 2)

class FlexibleIdentityBlock(nn.Module):
    """No-op stand-in used when the mid block is configured with zero resnets and zero attentions."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        return hidden_states

class FlexibleUNet2DConditionModel(UNet2DConditionModel):
    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        super().__init__(sample_size=self.configurations.get('sample_size',
                                                             FlexibleUnetConfigurations['sample_size']),
                         cross_attention_dim=self.configurations.get('cross_attention_dim',
                                                                     FlexibleUnetConfigurations['cross_attention_dim']))

        num_attention_heads = self.configurations.get("num_attention_heads")
        cross_attention_dim = self.configurations.get("cross_attention_dim")
        mix_block_in_forward = self.configurations.get("mix_block_in_forward")
        resnet_act_fn = self.configurations.get("resnet_act_fn")
        resnet_eps = self.configurations.get("resnet_eps")
        temb_dim = self.configurations.get("temb_dim")

        # Down blocks: replace the blocks built by the parent class with flexible ones
        down_blocks_num_attentions = self.configurations.get("down_blocks_num_attentions")
        down_blocks_out_channels = self.configurations.get("down_blocks_out_channels")
        down_blocks_in_channels = self.configurations.get("down_blocks_in_channels")
        down_blocks_num_resnets = self.configurations.get("down_blocks_num_resnets")
        add_downsample = self.configurations.get("add_downsample")

        self.down_blocks = nn.ModuleList()

        for i, (in_c, out_c, n_res, n_att, add_down) in enumerate(zip(down_blocks_in_channels,
                                                                      down_blocks_out_channels,
                                                                      down_blocks_num_resnets,
                                                                      down_blocks_num_attentions,
                                                                      add_downsample)):
            last_block = i == len(down_blocks_in_channels) - 1
            self.down_blocks.append(FlexibleCrossAttnDownBlock2D(in_channels=in_c,
                                                                 out_channels=out_c,
                                                                 temb_channels=temb_dim,
                                                                 num_resnets=n_res,
                                                                 num_attentions=n_att,
                                                                 resnet_eps=resnet_eps,
                                                                 resnet_act_fn=resnet_act_fn,
                                                                 num_attention_heads=num_attention_heads,
                                                                 cross_attention_dim=cross_attention_dim,
                                                                 add_downsample=add_down,
                                                                 last_block=last_block,
                                                                 mix_block_in_forward=mix_block_in_forward))

        # Mid block
        mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
        mid_num_attentions = self.configurations.get("mid_num_attentions")
        mid_num_resnets = self.configurations.get("mid_num_resnets")

        if mid_num_resnets == mid_num_attentions == 0:
            self.mid_block = FlexibleIdentityBlock()
        else:
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(in_channels=down_blocks_out_channels[-1],
                                                             temb_channels=temb_dim,
                                                             resnet_act_fn=resnet_act_fn,
                                                             resnet_eps=resnet_eps,
                                                             cross_attention_dim=cross_attention_dim,
                                                             num_attention_heads=num_attention_heads,
                                                             num_resnets=mid_num_resnets,
                                                             num_attentions=mid_num_attentions,
                                                             mix_block_in_forward=mix_block_in_forward,
                                                             add_upsample=mid_block_add_upsample)

        # Up blocks
        up_blocks_num_attentions = self.configurations.get("up_blocks_num_attentions")
        up_blocks_num_resnets = self.configurations.get("up_blocks_num_resnets")
        prev_output_channels = self.configurations.get("prev_output_channels")
        up_upsample = self.configurations.get("add_upsample")

        self.up_blocks = nn.ModuleList()
        for in_c, out_c, prev_out, n_res, n_att, add_up in zip(reversed(down_blocks_in_channels),
                                                               reversed(down_blocks_out_channels),
                                                               prev_output_channels,
                                                               up_blocks_num_resnets,
                                                               up_blocks_num_attentions,
                                                               up_upsample):
            self.up_blocks.append(FlexibleCrossAttnUpBlock2D(in_channels=in_c,
                                                             out_channels=out_c,
                                                             prev_output_channel=prev_out,
                                                             temb_channels=temb_dim,
                                                             num_resnets=n_res,
                                                             num_attentions=n_att,
                                                             resnet_eps=resnet_eps,
                                                             resnet_act_fn=resnet_act_fn,
                                                             num_attention_heads=num_attention_heads,
                                                             cross_attention_dim=cross_attention_dim,
                                                             add_upsample=add_up,
                                                             mix_block_in_forward=mix_block_in_forward))

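# Hedged smoke test (added sketch, not part of the original module; call it
# manually, e.g. from a test). It verifies that the flexible UNet keeps the
# standard Stable Diffusion latent interface implied by the configuration
# above: 4-channel 64x64 latents conditioned on 77x768 CLIP text embeddings.
def _flexible_unet_smoke_test():
    unet = FlexibleUNet2DConditionModel()
    sample = torch.randn(1, 4, 64, 64)
    encoder_hidden_states = torch.randn(1, 77, 768)
    out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
    assert out.shape == sample.shape, f'unexpected output shape {out.shape}'
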
class FlexibleCrossAttnDownBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()

        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        modules = []

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(
                itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            in_channels = in_channels if i == 0 else out_channels
            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        output_states = ()

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
                # A skip connection is emitted after every resnet
                output_states = output_states + (hidden_states,)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f'Got an unexpected module in modules list! {type(module)}')

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        if not self.last_block:
            output_states = output_states + (hidden_states,)

        return hidden_states, output_states

class FlexibleCrossAttnUpBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        modules = []

        # UNet2DConditionModel.forward slices skip connections with
        # len(upsample_block.resnets), so keep a placeholder entry per resnet.
        self.resnets = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(
                itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            if add_resnet:
                self.resnets += [True]
                modules.append(
                    ResnetBlock2D(
                        in_channels=resnet_in_channels + res_skip_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                # Each resnet consumes one skip connection from the down path
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = module(hidden_states, temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # There is always at least one resnet at the start of the mid block
        modules = [ResnetBlock2D(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )]

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(
                itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )

            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # The initial resnet runs first; iterate over the remaining modules only,
        # so it is not applied twice.
        hidden_states = self.modules_list[0](hidden_states, temb)

        for module in self.modules_list[1:]:
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states

class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # Input projection
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Output projection
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        # Project in: (B, C, H, W) -> (B, H*W, inner_dim)
        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Transformer blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # Project out: (B, H*W, inner_dim) -> (B, C, H, W)
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

class DeciDiffusionPipeline(StableDiffusionPipeline):
    deci_default_number_of_iterations = 30
    deci_default_guidance_rescale = 0.7

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNet2DConditionModel,
                 scheduler: KarrasDiffusionSchedulers,
                 safety_checker: StableDiffusionSafetyChecker,
                 feature_extractor: CLIPImageProcessor,
                 requires_safety_checker: bool = True,
                 ):
        # Discard the UNet passed in by the loader and build the flexible one instead
        del unet
        unet = FlexibleUNet2DConditionModel()

        super().__init__(vae=vae,
                         text_encoder=text_encoder,
                         tokenizer=tokenizer,
                         unet=unet,
                         scheduler=scheduler,
                         safety_checker=safety_checker,
                         feature_extractor=feature_extractor,
                         requires_safety_checker=requires_safety_checker)

        self.register_modules(vae=vae,
                              text_encoder=text_encoder,
                              tokenizer=tokenizer,
                              unet=unet,
                              scheduler=scheduler,
                              safety_checker=safety_checker,
                              feature_extractor=feature_extractor)

    def __call__(self, *args, **kwargs):
        # Apply DeciDiffusion's defaults unless the caller overrides them
        kwargs.setdefault('guidance_rescale', self.deci_default_guidance_rescale)
        kwargs.setdefault('num_inference_steps', self.deci_default_number_of_iterations)
        return super().__call__(*args, **kwargs)

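if __name__ == '__main__':
    # Usage sketch (added example, not part of the original module). The
    # checkpoint name below is illustrative: it assumes a Hub repository that
    # ships this file as its custom pipeline.
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        'Deci/DeciDiffusion-v1-0',
        custom_pipeline='Deci/DeciDiffusion-v1-0',
        torch_dtype=torch.float16,
    ).to('cuda')
    image = pipe('A photo of an astronaut riding a horse on Mars').images[0]
    image.save('astronaut.png')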