import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, FeedForward
from ..attention_processor import Attention, CogVideoXAttnProcessor2_0
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage


logger = logging.get_logger(__name__)


class CogView3PlusTransformerBlock(nn.Module):
    r"""
    Transformer block used in the [CogView3](https://github.com/THUDM/CogView3) model.

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in the timestep embedding.
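
    Example (a minimal, hedged sketch: a toy configuration with random weights whose dimensions are chosen for
    illustration, not the released checkpoint sizes; the import path is assumed from this module's location):

    ```python
    import torch

    from diffusers.models.transformers.transformer_cogview3plus import CogView3PlusTransformerBlock

    # dim must equal num_attention_heads * attention_head_dim
    block = CogView3PlusTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16, time_embed_dim=32)

    hidden_states = torch.randn(1, 16, 64)  # (batch, image tokens, dim)
    encoder_hidden_states = torch.randn(1, 4, 64)  # (batch, text tokens, dim)
    emb = torch.randn(1, 32)  # conditioning embedding of width time_embed_dim

    hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, emb)
    ```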
| """ |

    def __init__(
        self,
        dim: int = 2560,
        num_attention_heads: int = 64,
        attention_head_dim: int = 40,
        time_embed_dim: int = 512,
    ):
        super().__init__()

        self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            out_dim=dim,
            bias=True,
            qk_norm="layer_norm",
            elementwise_affine=False,
            eps=1e-6,
            processor=CogVideoXAttnProcessor2_0(),
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)

        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        emb: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate: adaptive layer norm produces per-sample scale/shift/gate values for both streams
        (
            norm_hidden_states,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            norm_encoder_hidden_states,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        ) = self.norm1(hidden_states, encoder_hidden_states, emb)

        # attention over the joint text + image sequence
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )

        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        # feed-forward over the concatenated sequence, then split back into text and image parts
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]

        # clamp to the fp16 representable range to avoid overflow in half precision
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        return hidden_states, encoder_hidden_states


class CogView3PlusTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin):
    r"""
    The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
    Diffusion](https://huggingface.co/papers/2403.05121).

    Args:
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, defaults to `40`):
            The number of channels in each head.
        num_attention_heads (`int`, defaults to `64`):
            The number of heads to use for multi-head attention.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        condition_dim (`int`, defaults to `256`):
            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
            crop_coords).
        pos_embed_max_size (`int`, defaults to `128`):
            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and
            added to the input patched latents, where `H` and `W` are the latent height and width respectively. A
            value of 128 means that the maximum supported height and width for image generation is `128 *
            vae_scale_factor * patch_size => 128 * 8 * 2 => 2048`.
        sample_size (`int`, defaults to `128`):
            The base resolution of input latents. If height/width is not provided during generation, this value is
            used to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`.
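
    Example (a minimal, hedged sketch: a tiny randomly initialized configuration chosen so the snippet runs quickly;
    it does not reflect the released CogView3-Plus checkpoint sizes):

    ```python
    import torch

    from diffusers import CogView3PlusTransformer2DModel

    model = CogView3PlusTransformer2DModel(
        patch_size=2,
        in_channels=4,
        num_layers=1,
        attention_head_dim=8,
        num_attention_heads=4,
        out_channels=4,
        text_embed_dim=16,
        time_embed_dim=16,
        condition_dim=8,
        pos_embed_max_size=8,
        sample_size=16,
    )

    latents = torch.randn(1, 4, 16, 16)  # (batch_size, in_channels, latent height, latent width)
    prompt_embeds = torch.randn(1, 8, 16)  # (batch_size, text tokens, text_embed_dim)

    sample = model(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        timestep=torch.tensor([999]),
        original_size=torch.tensor([[1024.0, 1024.0]]),
        target_size=torch.tensor([[1024.0, 1024.0]]),
        crop_coords=torch.tensor([[0.0, 0.0]]),
    ).sample  # (1, 4, 16, 16)
    ```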
| """ |

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 30,
        attention_head_dim: int = 40,
        num_attention_heads: int = 64,
        out_channels: int = 16,
        text_embed_dim: int = 4096,
        time_embed_dim: int = 512,
        condition_dim: int = 256,
        pos_embed_max_size: int = 128,
        sample_size: int = 128,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # CogView3 uses three SDXL-like resolution conditions (original_size, target_size, crop_coords),
        # each embedded as a sinusoidal vector of size 2 * condition_dim
        self.pooled_projection_dim = 3 * 2 * condition_dim

        self.patch_embed = CogView3PlusPatchEmbed(
            in_channels=in_channels,
            hidden_size=self.inner_dim,
            patch_size=patch_size,
            text_hidden_size=text_embed_dim,
            pos_embed_max_size=pos_embed_max_size,
        )

        self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
            embedding_dim=time_embed_dim,
            condition_dim=condition_dim,
            pooled_projection_dim=self.pooled_projection_dim,
            timesteps_dim=self.inner_dim,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                CogView3PlusTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            embedding_dim=self.inner_dim,
            conditioning_embedding_dim=time_embed_dim,
            elementwise_affine=False,
            eps=1e-6,
        )
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        original_size: torch.Tensor,
        target_size: torch.Tensor,
        crop_coords: torch.Tensor,
        return_dict: bool = True,
    ) -> tuple[torch.Tensor] | Transformer2DModelOutput:
| """ |
| The [`CogView3PlusTransformer2DModel`] forward method. |
| |
| Args: |
| hidden_states (`torch.Tensor`): |
| Input `hidden_states` of shape `(batch size, channel, height, width)`. |
| encoder_hidden_states (`torch.Tensor`): |
| Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape |
| `(batch_size, sequence_len, text_embed_dim)` |
| timestep (`torch.LongTensor`): |
| Used to indicate denoising step. |
| original_size (`torch.Tensor`): |
| CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
| target_size (`torch.Tensor`): |
| CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
| crop_coords (`torch.Tensor`): |
| CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of |
| [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). |
| return_dict (`bool`, *optional*, defaults to `True`): |
| Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain |
| tuple. |
| |
| Returns: |
| `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]: |
| The denoised latents using provided inputs as conditioning. |
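
        Example (a hedged sketch of the micro-conditioning inputs; the shapes follow how this method consumes them
        and the values are purely illustrative):

        ```python
        import torch

        # each micro-condition is a (batch_size, 2) tensor: (height, width) for the size conditions
        # and (top, left) for the crop coordinates, in pixel space
        original_size = torch.tensor([[1024.0, 1024.0]])
        target_size = torch.tensor([[1024.0, 1024.0]])
        crop_coords = torch.tensor([[0.0, 0.0]])
        timestep = torch.tensor([999])  # one denoising step index per batch element
        ```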
| """ |
        height, width = hidden_states.shape[-2:]
        text_seq_length = encoder_hidden_states.shape[1]

        # patchify the latents, project the text embeddings, and add positional embeddings
        hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)

        # split the joint sequence back into text and image streams for the transformer blocks
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    emb=emb,
                )

        hidden_states = self.norm_out(hidden_states, emb)
        hidden_states = self.proj_out(hidden_states)  # (batch_size, num_patches, patch_size * patch_size * out_channels)

        # unpatchify: fold the per-patch channels back into a (batch_size, out_channels, height, width) latent
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
        )
        hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)