import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from typing import Optional, Tuple, Union

from diffusers.models.attention import Attention, FeedForward
from einops import rearrange


def video_to_image(func):
    """Run an image-space (4D) op on a 5D video tensor by folding time into batch."""

    def wrapper(self, x, *args, **kwargs):
        if x.dim() == 5:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = func(self, x, *args, **kwargs)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        else:
            # 4D (image) inputs previously bypassed the wrapped op entirely;
            # apply it directly instead.
            x = func(self, x, *args, **kwargs)
        return x

    return wrapper


def nonlinearity(x):
    # Swish / SiLU.
    return x * torch.sigmoid(x)


def cast_tuple(t, length=1):
    return t if isinstance(t, (tuple, list)) else ((t,) * length)


class Block(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)


class Conv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        padding: Union[str, int, Tuple[int]] = 0,
        dilation: Union[int, Tuple[int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    @video_to_image
    def forward(self, x):
        return super().forward(x)


class CausalConv3d(nn.Module):
    def __init__(
        self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.time_kernel_size = self.kernel_size[0]
        self.chan_in = chan_in
        self.chan_out = chan_out
        stride = kwargs.pop("stride", 1)
        padding = kwargs.pop("padding", 0)
        padding = list(cast_tuple(padding, 3))
        padding[0] = 0  # temporal padding is applied causally in forward()
        stride = cast_tuple(stride, 3)
        self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
        self._init_weights(init_method)

    def _init_weights(self, init_method):
        if init_method == "avg":
            # Initialize the kernel as a temporal 3-tap average: identity across
            # channels with weight 1/3 on each temporal tap. Requires a 1x1
            # spatial kernel and matching channel counts.
            assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
            assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
            weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
            eyes = torch.concat(
                [
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                    torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
                ],
                dim=-1,
            )
            weight[:, :, :, 0, 0] = eyes
            self.conv.weight = nn.Parameter(weight, requires_grad=True)
        elif init_method == "zero":
            self.conv.weight = nn.Parameter(
                torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
                requires_grad=True,
            )
        if self.conv.bias is not None:
            nn.init.constant_(self.conv.bias, 0)

    def forward(self, x):
        # Causal temporal padding: prepend (time_kernel_size - 1) copies of the
        # first frame so no output frame depends on future frames. This supports
        # the "1 + 16" layout where the first frame is treated as an image and
        # the remaining frames as video.
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))  # (b, c, t, h, w)
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return self.conv(x)


class CausalConv3d_GC(CausalConv3d):
    """CausalConv3d that gradient-checkpoints the convolution to save memory."""

    def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int]], init_method="random", **kwargs):
        super().__init__(chan_in, chan_out, kernel_size, init_method, **kwargs)

    def forward(self, x):
        first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))  # (b, c, t, h, w)
        x = torch.concatenate((first_frame_pad, x), dim=2)
        return checkpoint(self.conv, x, use_reentrant=False)
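
# Hedged usage sketch (not part of the original module): `_demo_causal_conv3d` is a
# hypothetical helper illustrating how CausalConv3d preserves the temporal length by
# left-padding with copies of the first frame instead of symmetric zero padding.
def _demo_causal_conv3d():
    conv = CausalConv3d(chan_in=4, chan_out=4, kernel_size=3, stride=1, padding=1)
    x = torch.randn(2, 4, 16, 8, 8)  # (batch, channels, time, height, width)
    y = conv(x)
    # Two copies of frame 0 are prepended (time_kernel_size - 1 == 2), so the
    # 3-tap temporal kernel never sees future frames and t is preserved.
    assert y.shape == (2, 4, 16, 8, 8)
    return y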
class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        # Data-dependent init: set loc/scale from the first batch so outputs start
        # with zero mean and unit variance per channel.
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
            std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )

    def forward(self, image_embeds: torch.Tensor):
        r"""
        Args:
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width)
                or (batch_size, channels, height, width).

        Returns:
            embeds (`torch.Tensor`):
                (batch_size, num_frames x height x width, embed_dim) or
                (batch_size, 1 x height x width, embed_dim).
        """
        if image_embeds.dim() == 5:
            batch, num_frames, channels, height, width = image_embeds.shape
            image_embeds = image_embeds.reshape(-1, channels, height, width)
        else:
            batch, channels, height, width = image_embeds.shape
            num_frames = 1
        image_embeds = self.proj(image_embeds)
        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        return image_embeds


class DiscTransformer(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = LayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # 2. Feed-forward
        self.norm2 = LayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
    ) -> torch.Tensor:
        # norm & modulate
        norm_hidden_states, gate_msa = self.norm1(hidden_states, temb)

        # attention
        attn_output = self.attn1(hidden_states=norm_hidden_states)
        hidden_states = hidden_states + gate_msa * attn_output

        # norm & modulate
        norm_hidden_states, gate_ff = self.norm2(hidden_states, temb)

        # feed-forward
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + gate_ff * ff_output

        return hidden_states
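
# Hedged usage sketch (assumption, not from the original file): `_demo_patch_embed`
# shows the token layout PatchEmbed produces. With patch_size=2, an 8x8 frame yields
# a 4x4 grid of patches, and frames are concatenated along the sequence dimension.
def _demo_patch_embed():
    embed = PatchEmbed(patch_size=2, in_channels=16, embed_dim=64)
    video = torch.randn(2, 3, 16, 8, 8)  # (batch, num_frames, channels, height, width)
    tokens = embed(video)
    # 3 frames x (8/2) x (8/2) = 48 tokens per sample, each of size embed_dim.
    assert tokens.shape == (2, 48, 64)
    return tokens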
class LayerNormZero(nn.Module):
    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embedding_dim
        self.silu = nn.SiLU()
        # Predict per-sample shift, scale, and gate from the conditioning embedding.
        self.linear = nn.Linear(conditioning_dim, 3 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate = self.linear(self.silu(temb)).chunk(3, dim=1)
        if len(hidden_states.shape) == 3:
            # Broadcast the per-sample modulation over the sequence dimension.
            hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
            gate = gate[:, None, :]
        else:
            hidden_states = self.norm(hidden_states) * (1 + scale) + shift
        return hidden_states, gate


class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
        output_dim (`int`, *optional*): The size of the projected embedding; defaults to `embedding_dim * 2`.
        norm_elementwise_affine (`bool`, defaults to `False`): Whether the LayerNorm has learnable affine parameters.
        norm_eps (`float`, defaults to `1e-5`): Epsilon added to the LayerNorm denominator for numerical stability.
        chunk_dim (`int`, defaults to `0`): The dimension along which the projected embedding is split into scale and shift.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        self.chunk_dim = chunk_dim
        output_dim = output_dim or embedding_dim * 2

        if num_embeddings is not None:
            self.emb = nn.Embedding(num_embeddings, embedding_dim)
        else:
            self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if self.emb is not None:
            temb = self.emb(timestep)

        temb = self.linear(self.silu(temb))

        if self.chunk_dim == 1:
            # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
            # other if-branch. This branch is specific to CogVideoX for now.
            shift, scale = temb.chunk(2, dim=1)
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        else:
            scale, shift = temb.chunk(2, dim=0)

        x = self.norm(x) * (1 + scale) + shift
        return x
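
# Hedged usage sketch (assumption, not part of the original file): `_demo_disc_transformer`
# runs one DiscTransformer block end to end. LayerNormZero projects `temb` into
# (shift, scale, gate) triples; shift/scale modulate the normalized activations and
# gate scales each residual branch, in the spirit of adaLN-Zero.
def _demo_disc_transformer():
    block = DiscTransformer(
        dim=64,
        num_attention_heads=4,
        attention_head_dim=16,
        time_embed_dim=32,
    )
    hidden_states = torch.randn(2, 48, 64)  # (batch, sequence, dim)
    temb = torch.randn(2, 32)               # conditioning embedding
    out = block(hidden_states, temb)
    assert out.shape == hidden_states.shape  # residual block preserves shape
    return out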