| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Any, Dict, Optional, Tuple, Union |
| |
|
| | import os |
| | import json |
| | import torch |
| | import glob |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.models.embeddings import CogVideoXPatchEmbed |
| | from diffusers.utils import is_torch_version, logging |
| | from diffusers.utils.torch_utils import maybe_allow_in_graph |
| | from diffusers.models.attention import Attention, FeedForward |
| | from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 |
| | from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed, get_2d_sincos_pos_embed |
| | from diffusers.models.modeling_outputs import Transformer2DModelOutput |
| | from diffusers.models.modeling_utils import ModelMixin |
| | from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero |
| |
|
| | from einops import rearrange |
| |
|
# Module-level logger obtained through the `logging` helper imported from `diffusers.utils`.
logger = logging.get_logger(__name__)
| |
|
def zero_module(module):
    """
    Fill every parameter of `module` with zeros in place and return the same module.
    """
    # The in-place fill must not be recorded by autograd.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
| |
|
class FloatGroupNorm(nn.GroupNorm):
    """
    `nn.GroupNorm` that runs in the dtype of its own bias and returns the result in the caller's dtype.
    """

    def forward(self, x):
        caller_dtype = x.dtype
        # Normalise in the parameter dtype, then cast back to what the caller passed in.
        normalized = super().forward(x.to(self.bias.dtype))
        return normalized.type(caller_dtype)
| | |
class CogVideoXPatchEmbed(nn.Module):
    r"""
    Projects text embeddings and video latents into one token sequence and, optionally, adds a joint
    positional embedding whose image part is resized to the incoming spatial resolution.

    Parameters:
        patch_size (`int`, defaults to `2`):
            Spatial patch size.
        patch_size_t (`int`, *optional*, defaults to `None`):
            Temporal patch size. When `None`, frames are patchified one by one with a 2D convolution; otherwise
            flattened spatio-temporal patches are projected with a linear layer.
        in_channels (`int`, defaults to `16`):
            Number of channels of the incoming latents.
        embed_dim (`int`, defaults to `1920`):
            Output embedding dimension.
        text_embed_dim (`int`, defaults to `4096`):
            Dimension of the incoming text embeddings.
        bias (`bool`, defaults to `True`):
            Bias flag of the convolutional projection (only used when `patch_size_t` is `None`).
        sample_width / sample_height / sample_frames (`int`):
            Reference sizes used to build the positional embedding buffer.
        temporal_compression_ratio (`int`, defaults to `4`):
            Used to derive the number of latent frames from `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            Number of leading rows reserved for text in the joint positional embedding.
        spatial_interpolation_scale / temporal_interpolation_scale (`float`):
            Forwarded to `get_3d_sincos_pos_embed`.
        use_positional_embeddings (`bool`, defaults to `True`):
            Whether to add the positional embedding in `forward`.
        use_learned_positional_embeddings (`bool`, defaults to `True`):
            Also enables the positional embedding and makes the buffer persistent (saved in the state dict).
    """

    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
        use_learned_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        post_patch_height = sample_height // patch_size
        post_patch_width = sample_width // patch_size
        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
        self.post_patch_height = post_patch_height
        self.post_patch_width = post_patch_width
        self.post_time_compression_frames = post_time_compression_frames
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        if patch_size_t is None:
            # Per-frame patchify: one strided 2D convolution.
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # Spatio-temporal patchify: flattened (p_t, p, p) patches go through a linear layer.
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:
            # Only learned embeddings are stored in checkpoints.
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        """
        Build the joint positional embedding of shape `(1, max_text_seq_length + num_patches, embed_dim)`.

        The first `max_text_seq_length` rows (text) are zeros; the remaining rows hold the 3D sin-cos
        embedding returned by `get_3d_sincos_pos_embed`.
        """
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
        joint_pos_embedding = torch.zeros(
            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
        )
        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

        return joint_pos_embedding

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor, flow_embeds: Optional[torch.Tensor] = None):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
            flow_embeds (`torch.Tensor`, *optional*):
                Flow latents, unpacked as (batch_size, num_frames, channels, height, width). They are only
                re-laid out into patches when `patch_size_t` is set; otherwise they are returned untouched.

        Returns:
            `tuple`: `(embeds, flow_embeds)` where `embeds` is the concatenation of projected text tokens and
            projected image tokens along dim 1 (plus positional embeddings when enabled), and `flow_embeds` is
            the (possibly patchified) flow tensor or `None`.

        Raises:
            ValueError: if positional embeddings are enabled and the text sequence is longer than
                `max_text_seq_length`.
        """
        text_embeds = self.text_proj(text_embeds)

        text_batch_size, text_seq_length, text_channels = text_embeds.shape
        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            # channels-last, then split every axis into (blocks, patch) pairs
            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            # -> [batch, tokens, channels * p_t * p * p]
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)

            image_embeds = self.proj(image_embeds)

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if text_seq_length > self.max_text_seq_length:
                raise ValueError(
                    f"text sequence length {text_seq_length} exceeds max_text_seq_length "
                    f"{self.max_text_seq_length} reserved in the positional embedding."
                )

            seq_length = height * width * num_frames // (self.patch_size**2)

            pos_embeds = self.pos_embedding
            emb_size = embeds.size()[-1]
            # The image grid starts after the `max_text_seq_length` reserved text rows (see
            # `_get_positional_embeddings`), independent of how long the current prompt is. Slicing at
            # `text_seq_length` only produced a reshapeable tensor when both lengths were equal.
            pos_embeds_without_text = pos_embeds[:, self.max_text_seq_length :].view(
                1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size
            )
            pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
            # Resize the spatial grid to the incoming resolution; the temporal extent is kept as built.
            pos_embeds_without_text = F.interpolate(
                pos_embeds_without_text,
                size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size],
                mode='trilinear',
                align_corners=False,
            )
            pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
            pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim=1)
            pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
            embeds = embeds + pos_embeds

        if flow_embeds is not None:
            if self.patch_size_t is not None:
                # `p` / `p_t` were bound in the patchify branch above (same `patch_size_t is not None` condition).
                _, _, flow_channels, _, _ = flow_embeds.shape

                flow_embeds = flow_embeds.permute(0, 1, 3, 4, 2)
                flow_embeds = flow_embeds.reshape(
                    batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, flow_channels
                )
                # -> [batch, frames // p_t, height // p, width // p, flow_channels * p_t * p * p]
                flow_embeds = flow_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7)

        return embeds, flow_embeds
| |
|
| |
|
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model, extended with an optional
    flow-conditioning branch that is added to the video tokens between attention and feed-forward.

    Parameters:
        block_idx (`int`):
            Index of this block; together with `block_interval` it decides whether the flow branch is created.
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        block_interval (`int`, defaults to `1`):
            The flow branch is created only when `block_idx % block_interval == 0`.
        flow_in_dim (`int`, defaults to `128`):
            Input channels of the spatial flow convolution.
        out_dim (`int`, defaults to `3072`):
            Output channels of the flow branch. NOTE(review): the result is added to `hidden_states`, so this
            presumably has to equal `dim` - confirm against the caller.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool` defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
        finetune_init (`bool`, defaults to `False`):
            When `True`, the flow branch is never created.
    """

    def __init__(
        self,
        block_idx: int,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        block_interval: int = 1,
        flow_in_dim: int = 128,
        out_dim: int = 3072,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        finetune_init: bool = False,
    ):
        super().__init__()

        # --- attention sub-layer -------------------------------------------------
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        qk_norm_kind = "layer_norm" if qk_norm else None
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm=qk_norm_kind,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # --- optional flow branch ------------------------------------------------
        has_flow_branch = (not finetune_init) and block_idx % block_interval == 0
        if has_flow_branch:
            bottleneck_dim = out_dim // 4
            self.flow_spatial = nn.Conv2d(flow_in_dim, bottleneck_dim, 3, padding=1)
            temporal_conv = nn.Conv1d(
                bottleneck_dim,
                out_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="replicate",
            )
            # The temporal convolution has all of its parameters zeroed right after creation.
            self.flow_temporal = zero_module(temporal_conv)
            self.flow_cond_norm = FloatGroupNorm(32, out_dim)

        # --- feed-forward sub-layer ----------------------------------------------
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def _flow_residual(self, flow_states):
        """Run the flow branch and return the term that is added to the video tokens."""
        _, num_frames, grid_h, grid_w, _ = flow_states.shape

        # Spatial convolution over every frame.
        flow_states = rearrange(flow_states, "bz f h w c -> (bz f) c h w")
        flow_states = self.flow_spatial(flow_states)

        # Temporal convolution over every spatial location.
        flow_states = rearrange(flow_states, "(bz f) c h w -> (bz h w) c f", f=num_frames)
        flow_states = self.flow_temporal(flow_states)
        flow_states = rearrange(flow_states, "(bz h w) c f -> bz (f h w) c", f=num_frames, h=grid_h, w=grid_w)

        # Group-normalised copy of the same features, brought back to token layout.
        per_frame = rearrange(flow_states, "bz (f h w) c -> (bz f) c h w", h=grid_h, w=grid_w)
        norm_flow_states = self.flow_cond_norm(per_frame)
        norm_flow_states = rearrange(
            norm_flow_states, "(bz f) c h w -> bz (f h w) c", f=num_frames, h=grid_h, w=grid_w
        )

        return norm_flow_states * flow_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        flow_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Apply attention, the optional flow residual, and feed-forward.

        Args:
            hidden_states: video tokens.
            encoder_hidden_states: text tokens.
            flow_states: flow features unpacked as (batch, frames, height, width, channels), or `None` to skip
                the flow branch.
            temb: timestep embedding passed to both modulated layer norms.
            image_rotary_emb: optional rotary embedding forwarded to the attention layer.

        Returns:
            The updated `(hidden_states, encoder_hidden_states)` pair.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # Modulated norm, then joint attention with gated residuals.
        normed_video, normed_text, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )
        attn_video, attn_text = self.attn1(
            hidden_states=normed_video,
            encoder_hidden_states=normed_text,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + gate_msa * attn_video
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_text

        # Flow conditioning only applies when this block owns a flow branch and flow input was given.
        if flow_states is not None and hasattr(self, "flow_spatial"):
            hidden_states = hidden_states + self._flow_residual(flow_states)

        # Modulated norm, then feed-forward over the concatenated [text, video] sequence.
        normed_video, normed_text, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )
        ff_output = self.ff(torch.cat([normed_text, normed_video], dim=1))

        text_update = ff_output[:, :text_seq_length]
        video_update = ff_output[:, text_seq_length:]
        hidden_states = hidden_states + gate_ff * video_update
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * text_update

        return hidden_states, encoder_hidden_states
| |
|
| |
|
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to `0`):
            Forwarded to the `Timesteps` projection.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        patch_size_t (`int`, *optional*, defaults to `None`):
            Temporal patch size forwarded to the patch embedding; also selects the output un-patchify layout.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
        use_rotary_positional_embeddings (`bool`, defaults to `False`):
            When `True`, the patch embedding does not add sin-cos positional embeddings and the final layer norm
            is applied to the concatenated text and video tokens.
        use_learned_positional_embeddings (`bool`, defaults to `False`):
            Forwarded to the patch embedding. Not allowed without rotary embeddings.
        patch_bias (`bool`, defaults to `True`):
            Bias flag forwarded to the patch embedding.
        add_noise_in_inpaint_model (`bool`, defaults to `False`):
            Stored in the config; not read anywhere in this class.
        finetune_init (`bool`, defaults to `False`):
            Forwarded to every `CogVideoXBlock`; when `True` the blocks are built without a flow branch.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
        add_noise_in_inpaint_model: bool = False,
        finetune_init: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        self.patch_size_t = patch_size_t
        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=patch_bias,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 3. Spatio-temporal transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    block_idx=idx,
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    finetune_init=finetune_init,
                )
                for idx in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )

        if patch_size_t is None:
            # One output vector per spatial patch.
            output_dim = patch_size * patch_size * out_channels
        else:
            # One output vector per spatio-temporal patch.
            output_dim = patch_size * patch_size * patch_size_t * out_channels

        self.proj_out = nn.Linear(inner_dim, output_dim)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        """Store the gradient-checkpointing flag on the model (the `module` argument is not used)."""
        self.gradient_checkpointing = value

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # NOTE: entries are popped, so a dict passed by the caller is consumed.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Remember the current processors so `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # `original_attn_processors` only exists after `fuse_qkv_projections()` ran; treat "never fused" as a
        # no-op instead of raising AttributeError.
        original_processors = getattr(self, "original_attn_processors", None)
        if original_processors is not None:
            self.set_attn_processor(original_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        inpaint_latents: Optional[torch.Tensor] = None,
        flow_latents: Optional[torch.Tensor] = None,
        control_latents: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        return_dict: bool = True,
    ):
        """
        Run the transformer on video latents.

        Args:
            hidden_states: latents, unpacked as (batch, frames, channels, height, width).
            encoder_hidden_states: text embeddings fed to the patch embedding's text projection.
            timestep: diffusion timestep(s) passed to the timestep projection.
            timestep_cond: optional extra conditioning passed to the timestep embedding.
            inpaint_latents: optional latents concatenated to `hidden_states` along dim 2.
            flow_latents: optional flow latents handed to the patch embedding and then to every block.
            control_latents: optional latents concatenated to `hidden_states` along dim 2.
            image_rotary_emb: optional rotary embedding forwarded to every block.
            return_dict: when `False`, a 1-tuple is returned instead of `Transformer2DModelOutput`.

        Returns:
            `Transformer2DModelOutput` with the un-patchified prediction in `sample`, or `(sample,)`.
        """
        batch_size, num_frames, channels, height, width = hidden_states.shape
        if num_frames == 1 and self.patch_size_t is not None:
            # A single frame is padded with one all-zero frame before temporal patchification; the extra frame
            # is cut off again at the end of this method.
            hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
            if inpaint_latents is not None:
                inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
            if control_latents is not None:
                control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
            local_num_frames = num_frames + 1
        else:
            local_num_frames = num_frames

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # Cast the projected timesteps to the dtype of the latents before the embedding MLP.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding (extra conditioning latents are stacked on the channel axis first)
        if inpaint_latents is not None:
            hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
        if control_latents is not None:
            hidden_states = torch.concat([hidden_states, control_latents], 2)

        hidden_states, flow_states = self.patch_embed(encoder_hidden_states, hidden_states, flow_latents)
        hidden_states = self.embedding_dropout(hidden_states)

        # The patch embedding returns [text tokens, video tokens]; split them again.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for i, block in enumerate(self.transformer_blocks):
            if use_checkpointing:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    flow_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    flow_states=flow_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

        if not self.config.use_rotary_positional_embeddings:
            # Final norm over the video tokens only.
            hidden_states = self.norm_final(hidden_states)
        else:
            # Final norm over [text, video] tokens, then drop the text tokens.
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if num_frames == 1:
            # Drop the zero frame that may have been appended above.
            output = output[:, :num_frames, :]

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    @staticmethod
    def _load_checkpoint_state_dict(pretrained_model_path):
        """
        Read the checkpoint found in `pretrained_model_path` into a plain state dict on CPU.

        Lookup order: the `WEIGHTS_NAME` file, the same name with a `.safetensors` suffix, then every
        `*.safetensors` file in the directory merged into one dict.

        Raises:
            RuntimeError: if none of these yields any tensor.
        """
        from diffusers.utils import WEIGHTS_NAME

        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        model_file_safetensors = model_file.replace(".bin", ".safetensors")

        if os.path.exists(model_file):
            # NOTE(security): `torch.load` unpickles the file - only load checkpoints from trusted sources.
            return torch.load(model_file, map_location="cpu")

        from safetensors.torch import load_file

        if os.path.exists(model_file_safetensors):
            return load_file(model_file_safetensors)

        state_dict = {}
        for shard_path in glob.glob(os.path.join(pretrained_model_path, "*.safetensors")):
            state_dict.update(load_file(shard_path))
        if not state_dict:
            raise RuntimeError(f"No model weights found in {pretrained_model_path}")
        return state_dict

    @classmethod
    def from_pretrained_2d(
        cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs=None,
        finetune_init=False, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
    ):
        """
        Build the model from `config.json` in a checkpoint directory and load as many weights as fit.

        Args:
            pretrained_model_path: directory with `config.json` and the weight file(s).
            subfolder: optional sub-directory appended to `pretrained_model_path`.
            transformer_additional_kwargs: extra keyword arguments passed to `cls.from_config`
                (`None` means no extra arguments).
            finetune_init: written into the loaded config as `finetune_init`.
            low_cpu_mem_usage: try to instantiate with empty weights via `accelerate` first; on any failure a
                message is printed and the regular loading path below is used instead.
            torch_dtype: dtype the weights/model are converted to.

        Returns:
            The instantiated model.

        Raises:
            RuntimeError: if `config.json` or the model weights cannot be found.
        """
        if transformer_additional_kwargs is None:
            transformer_additional_kwargs = {}

        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)
        config['finetune_init'] = finetune_init

        if low_cpu_mem_usage:
            try:
                import re

                from diffusers.models.modeling_utils import load_model_dict_into_meta
                from diffusers.utils import is_accelerate_available

                if not is_accelerate_available():
                    raise ImportError("`accelerate` is not installed")
                import accelerate

                # Instantiate the model with empty weights.
                with accelerate.init_empty_weights():
                    model = cls.from_config(config, **transformer_additional_kwargs)

                param_device = "cpu"
                state_dict = cls._load_checkpoint_state_dict(pretrained_model_path)
                model._convert_deprecated_attention_blocks(state_dict)

                # This path needs a checkpoint that covers every model key.
                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                if len(missing_keys) > 0:
                    raise ValueError(
                        f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                        " those weights or else make sure your checkpoint file is correct."
                    )

                unexpected_keys = load_model_dict_into_meta(
                    model,
                    state_dict,
                    device=param_device,
                    dtype=torch_dtype,
                    model_name_or_path=pretrained_model_path,
                )

                if cls._keys_to_ignore_on_load_unexpected is not None:
                    for pat in cls._keys_to_ignore_on_load_unexpected:
                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                if len(unexpected_keys) > 0:
                    print(
                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                    )
                return model
            except Exception as e:
                # Deliberate best-effort: report why the fast path failed and fall through to regular loading.
                print(
                    f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
                )

        model = cls.from_config(config, **transformer_additional_kwargs)
        state_dict = cls._load_checkpoint_state_dict(pretrained_model_path)

        # Built once: every `state_dict()` call creates a new dict; the tensors in it share storage with
        # the model parameters, so in-place writes below land in the model.
        model_state = model.state_dict()

        proj_key = 'patch_embed.proj.weight'
        target_weight = model_state[proj_key]
        source_weight = state_dict[proj_key]
        if target_weight.size() != source_weight.size():
            new_shape = target_weight.size()
            if len(new_shape) == 5:
                # Inflate along a new axis 2 and keep only the last slice non-zero.
                inflated = source_weight.unsqueeze(2).expand(new_shape).clone()
                inflated[:, :, :-1] = 0
                state_dict[proj_key] = inflated
            else:
                # Input-channel mismatch (dim 1): copy the channels both sides have; any additional model
                # channels are zero-filled, any additional checkpoint channels are dropped.
                shared_channels = min(target_weight.size()[1], source_weight.size()[1])
                target_weight[:, :shared_channels] = source_weight[:, :shared_channels]
                target_weight[:, shared_channels:] = 0
                state_dict[proj_key] = target_weight

        # Keep only entries that exist in the model with an identical shape.
        tmp_state_dict = {}
        for key in state_dict:
            if key in model_state and model_state[key].size() == state_dict[key].size():
                tmp_state_dict[key] = state_dict[key]
            else:
                print(key, "Size don't match, skip")

        state_dict = tmp_state_dict

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        print(m)

        params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
        print(f"### All Parameters: {sum(params) / 1e6} M")

        model = model.to(torch_dtype)
        return model