Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from typing import Tuple, Union, Optional | |
| from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_1d_rotary_pos_embed | |
class CogVideoXPatchEmbed(nn.Module):
    """
    Turn text embeddings and video latents into one joint token sequence.

    The text embeddings are linearly projected to `embed_dim`. The video latents are split into patches and
    projected to `embed_dim` as well, either with a per-frame `nn.Conv2d` (when `patch_size_t` is `None`) or with an
    `nn.Linear` applied to flattened spatio-temporal patches (when `patch_size_t` is set). The two token sequences are
    concatenated along the sequence dimension (text first) and, optionally, a 3D sincos positional embedding is added
    to the video part of the sequence.

    Args:
        patch_size (`int`, defaults to `2`):
            Spatial patch size.
        patch_size_t (`int`, *optional*):
            Temporal patch size. When `None`, frames are patchified independently with a 2D convolution.
        in_channels (`int`, defaults to `16`):
            Number of channels of the video latents.
        embed_dim (`int`, defaults to `1920`):
            Output embedding dimension of both projections.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of the text embeddings.
        bias (`bool`, defaults to `True`):
            Whether the convolutional projection uses a bias term.
        sample_width (`int`, defaults to `90`):
            Width (in the same units as the `width` of `image_embeds`) the cached positional embedding is built for.
        sample_height (`int`, defaults to `60`):
            Height (in the same units as the `height` of `image_embeds`) the cached positional embedding is built for.
        sample_frames (`int`, defaults to `49`):
            Number of frames, before temporal compression, the cached positional embedding is built for.
        temporal_compression_ratio (`int`, defaults to `4`):
            Ratio used to convert between frame counts before and after temporal compression.
        max_text_seq_length (`int`, defaults to `226`):
            Number of leading (text) positions that receive a zero positional embedding.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Forwarded to `get_3d_sincos_pos_embed`.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Forwarded to `get_3d_sincos_pos_embed`.
        use_positional_embeddings (`bool`, defaults to `True`):
            Whether to add positional embeddings in `forward`.
        use_learned_positional_embeddings (`bool`, defaults to `True`):
            When `True`, the positional embedding buffer is persistent (part of the state dict) and `forward` refuses
            inputs whose height or width differ from `sample_height` / `sample_width`.
    """

    def __init__(
        self,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_positional_embeddings: bool = True,
        use_learned_positional_embeddings: bool = True,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.embed_dim = embed_dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings

        if patch_size_t is None:
            # CogVideoX 1.0 checkpoints
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # CogVideoX 1.5 checkpoints
            # NOTE(review): `bias` is not forwarded here, so this layer always has a bias term. Left as is so the
            # set of parameters created for existing checkpoints does not change.
            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

        if use_positional_embeddings or use_learned_positional_embeddings:
            # The table is only stored in the state dict when it is flagged as learned.
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
        """
        Build the joint (text + video) positional embedding table.

        Args:
            sample_height (`int`): Height before spatial patchification.
            sample_width (`int`): Width before spatial patchification.
            sample_frames (`int`): Number of frames before temporal compression.

        Returns:
            `torch.Tensor`: Tensor of shape `(1, max_text_seq_length + num_patches, embed_dim)` whose first
            `max_text_seq_length` positions are zeros and whose remaining positions hold the 3D sincos embedding.
        """
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        pos_embedding = get_3d_sincos_pos_embed(
            self.embed_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
        )
        # `as_tensor` accepts the helper's result whether it is a NumPy array or already a tensor.
        pos_embedding = torch.as_tensor(pos_embedding).flatten(0, 1)

        joint_pos_embedding = torch.zeros(
            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
        )
        # Text positions stay zero; only the video positions receive the sincos embedding.
        joint_pos_embedding[:, self.max_text_seq_length :].copy_(pos_embedding)
        return joint_pos_embedding

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).

        Returns:
            `torch.Tensor`: Joint sequence of shape (batch_size, seq_length + num_video_tokens, embed_dim), text
            tokens first.

        Raises:
            ValueError: If learned positional embeddings are enabled and the input height or width differs from
                `sample_height` / `sample_width`.
        """
        text_embeds = self.text_proj(text_embeds)

        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            # Fold frames into the batch dimension so every frame is patchified by the 2D convolution.
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            # Move channels last, then cut the volume into (p_t, p, p) patches.
            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            # Gather (channels, p_t, p, p) of each patch into one vector and flatten the patch grid into a sequence.
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
                )

            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

            if (
                self.sample_height != height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                # Input differs from the cached configuration: rebuild the table for this call only.
                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
            else:
                pos_embedding = self.pos_embedding

            # Cast in both cases so a cached table with a different dtype cannot promote the output dtype.
            pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
            embeds = embeds + pos_embedding

        return embeds
def get_3d_rotary_pos_embed(
    embed_dim,
    crops_coords,
    grid_size,
    temporal_size,
    theta: int = 10000,
    use_real: bool = True,
    grid_type: str = "linspace",
    max_size: Optional[Tuple[int, int]] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    RoPE for video tokens with 3D structure.

    Args:
        embed_dim: (`int`):
            The embedding dimension size, corresponding to hidden_size_head. It is split into a temporal part
            (`embed_dim // 4`) and two spatial parts (`embed_dim // 8 * 3` each, for height and width).
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop. Only used when `grid_type` is "linspace".
        grid_size (`Tuple[int]`):
            The grid size of the spatial positional embedding (height, width).
        temporal_size (`int`):
            The size of the temporal dimension.
        theta (`float`):
            Scaling factor for frequency computation. Forwarded to `get_1d_rotary_pos_embed` for all three axes.
        use_real (`bool`):
            Must be `True`; any other value is rejected.
        grid_type (`str`):
            Whether to use "linspace" or "slice" to compute grids.
        max_size (`Tuple[int, int]`, *optional*):
            Maximum (height, width) of the position grid. Required when `grid_type` is "slice": positions
            `0..max-1` are embedded and then cut down to `grid_size`.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: The `(cos, sin)` tensors, each with
        `temporal_size * grid_size[0] * grid_size[1]` rows. The last dimension is the concatenation of the
        temporal, height and width components returned by `get_1d_rotary_pos_embed`.

    Raises:
        ValueError: If `use_real` is not `True`, if `grid_type` is neither "linspace" nor "slice", or if
            `grid_type` is "slice" and `max_size` is not given.
    """
    if use_real is not True:
        raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")

    if grid_type == "linspace":
        start, stop = crops_coords
        grid_size_h, grid_size_w = grid_size
        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
    elif grid_type == "slice":
        if max_size is None:
            raise ValueError('`max_size` must be provided when `grid_type` is "slice".')
        max_h, max_w = max_size
        grid_size_h, grid_size_w = grid_size
        grid_h = np.arange(max_h, dtype=np.float32)
        grid_w = np.arange(max_w, dtype=np.float32)
        grid_t = np.arange(temporal_size, dtype=np.float32)
    else:
        raise ValueError("Invalid value passed for `grid_type`.")

    # Compute dimensions for each axis
    dim_t = embed_dim // 4
    dim_h = embed_dim // 8 * 3
    dim_w = embed_dim // 8 * 3

    # Temporal frequencies
    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
    # Spatial frequencies for height and width
    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)

    # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
    def combine_time_height_width(freqs_t, freqs_h, freqs_w):
        freqs_t = freqs_t[:, None, None, :].expand(
            -1, grid_size_h, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_t
        freqs_h = freqs_h[None, :, None, :].expand(
            temporal_size, -1, grid_size_w, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_h
        freqs_w = freqs_w[None, None, :, :].expand(
            temporal_size, grid_size_h, -1, -1
        )  # temporal_size, grid_size_h, grid_size_w, dim_w

        freqs = torch.cat(
            [freqs_t, freqs_h, freqs_w], dim=-1
        )  # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
        freqs = freqs.view(
            temporal_size * grid_size_h * grid_size_w, -1
        )  # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
        return freqs

    t_cos, t_sin = freqs_t  # both t_cos and t_sin have shape: temporal_size, dim_t
    h_cos, h_sin = freqs_h  # both h_cos and h_sin have shape: grid_size_h, dim_h
    w_cos, w_sin = freqs_w  # both w_cos and w_sin have shape: grid_size_w, dim_w

    if grid_type == "slice":
        # The grids were built up to `max_size`; keep only the positions covered by the requested grid.
        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

    cos = combine_time_height_width(t_cos, h_cos, w_cos)
    sin = combine_time_height_width(t_sin, h_sin, w_sin)
    return cos, sin