Spaces:
Runtime error
Runtime error
| # Adapted from Open-Sora-Plan | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # -------------------------------------------------------- | |
| # References: | |
| # Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan | |
| # -------------------------------------------------------- | |
| import torch | |
| import collections | |
| from diffusers.models.embeddings import TimestepEmbedding, Timesteps | |
| from einops import rearrange | |
| from torch import nn | |
| from diffusers.utils import logging | |
| logger = logging.get_logger(__name__) | |
class CombinedTimestepSizeEmbeddings(nn.Module):
    """
    Timestep embedding with optional resolution / aspect-ratio conditioning.

    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29

    The timestep is projected to 256 channels by ``Timesteps`` and mapped to
    ``embedding_dim`` by ``TimestepEmbedding``. When
    ``use_additional_conditions`` is truthy, every scalar of ``resolution``
    and ``aspect_ratio`` is embedded to ``size_emb_dim`` features, the two
    results are concatenated along dim 1 and added to the timestep embedding.
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        # Width of the embedding produced for each individual size scalar.
        self.outdim = size_emb_dim

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if not use_additional_conditions:
            return

        # Any truthy flag is stored as a real ``True``.
        self.use_additional_conditions = True
        self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
        self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
        """
        Embed every scalar in ``size`` and lay the results out per batch row.

        Args:
            size: 1-D tensor (treated as one scalar per row) or 2-D tensor of
                shape ``(rows, dims)``.
            batch_size: required number of rows; ``size`` is tiled along dim 0
                ``batch_size // rows`` times when the row count differs.
            embedder: callable applied to the projected scalars.

        Returns:
            Tensor of shape ``(batch_size, dims * self.outdim)``.

        Raises:
            ValueError: if the row count still differs from ``batch_size``
                after tiling.
        """
        if size.ndim == 1:
            size = size.unsqueeze(1)

        if size.shape[0] != batch_size:
            size = size.repeat(batch_size // size.shape[0], 1)
            if size.shape[0] != batch_size:
                raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")

        rows, dims = size.shape[0], size.shape[1]
        flat = size.reshape(-1)
        # Cast the projection back to the dtype of the incoming size tensor.
        freqs = self.additional_condition_proj(flat).to(flat.dtype)
        return embedder(freqs).reshape(rows, dims * self.outdim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """
        Build the conditioning vector for the given timestep.

        ``resolution``, ``aspect_ratio`` and ``batch_size`` are only read when
        additional conditions are enabled. ``hidden_dtype`` is the dtype the
        timestep projection is cast to before it is embedded.
        """
        projected = self.time_proj(timestep)
        conditioning = self.timestep_embedder(projected.to(dtype=hidden_dtype))  # (N, D)

        if not self.use_additional_conditions:
            return conditioning

        resolution_emb = self.apply_condition(
            resolution, batch_size=batch_size, embedder=self.resolution_embedder
        )
        aspect_ratio_emb = self.apply_condition(
            aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
        )
        size_emb = torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        return conditioning + size_emb
class PatchEmbed2D(nn.Module):
    """2D Image to Patch Embedding, applied independently to every frame.

    Each frame of a ``(B, C, T, H, W)`` latent is split into
    ``patch_size x patch_size`` patches by a strided ``nn.Conv2d`` that maps
    ``in_channels`` to ``embed_dim``.

    Args:
        patch_size: side length of a square patch (conv kernel and stride).
        in_channels: channel count of the incoming latent.
        embed_dim: channel count of the produced embedding.
        layer_norm: if true, apply a non-affine ``LayerNorm`` over the
            embedding channels.
        flatten: if true, ``forward`` returns a token sequence; otherwise it
            returns the per-frame patch grid.
        bias: whether the projection conv has a bias term.
        patch_size_t: stored on the module; not used by this class.
        use_abs_pos: stored on the module; not used by this class.
        num_frames, height, width, interpolation_scale,
        interpolation_scale_t: accepted for interface compatibility; not used
            by this class.
    """

    def __init__(
        self,
        num_frames=1,
        height=224,
        width=224,
        patch_size_t=1,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=(1, 1),
        interpolation_scale_t=1,
        use_abs_pos=False,
    ):
        super().__init__()
        self.use_abs_pos = use_abs_pos
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size_t = patch_size_t
        self.patch_size = patch_size

    def forward(self, latent):
        """Embed a video latent frame by frame.

        Args:
            latent: tensor of shape ``(B, C, T, H, W)``.

        Returns:
            ``(B, T * N, embed_dim)`` when ``self.flatten`` is true, where
            ``N`` is the number of patches per frame and tokens are ordered
            frame-major. Otherwise ``(B, embed_dim, T, H', W')``, mirroring
            the input layout with ``H'``/``W'`` the patch-grid size.

        Raises:
            ValueError: if ``latent`` is not 5-dimensional.
        """
        if latent.ndim != 5:
            raise ValueError(
                f"`latent` must have shape (B, C, T, H, W) but got {latent.ndim} dimensions."
            )
        batch, channels, frames, height, width = latent.shape

        # b c t h w -> (b t) c h w, so the 2D conv sees one frame per sample.
        per_frame = latent.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        per_frame = self.proj(per_frame)
        embed_dim, grid_h, grid_w = per_frame.shape[1:]

        # BT C H W -> BT N C. Normalisation always runs in this layout so it
        # acts on the channel axis whether or not the output is flattened.
        tokens = per_frame.flatten(2).transpose(1, 2)
        if self.layer_norm:
            tokens = self.norm(tokens)

        if self.flatten:
            # (b t) n c -> b (t n) c
            return tokens.reshape(batch, frames * grid_h * grid_w, embed_dim)

        # Restore the spatial grid and put channels ahead of time again.
        grid = tokens.transpose(1, 2).reshape(batch, frames, embed_dim, grid_h, grid_w)
        return grid.permute(0, 2, 1, 3, 4)