| |
|
|
| import torch |
| from torch import nn |
| from einops import rearrange, repeat |
| from einops.layers.torch import Rearrange |
| import logging |
| from typing import Callable, Optional, Tuple, List |
| import math |
| from torchvision import transforms |
| from ..core.attention import attention_forward |
| from ..core.gradient import gradient_checkpoint_forward |
|
|
|
|
| class VideoPositionEmb(nn.Module): |
| def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor: |
| """ |
| Delegates embedding generation to the `generate_embeddings` method implemented by subclasses. |
| """ |
| B_T_H_W_C = x_B_T_H_W_C.shape |
| embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype) |
|
|
| return embeddings |
|
|
| def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None): |
| raise NotImplementedError |
|
|
|
|
| def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: |
| """ |
| Normalizes the input tensor along the specified dimensions so that the root-mean-square magnitude of its elements becomes approximately one. |
| |
| Args: |
| x (torch.Tensor): The input tensor to normalize. |
| dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. |
| eps (float, optional): A small constant to ensure numerical stability during division. |
| |
| Returns: |
| torch.Tensor: The normalized tensor. |
| """ |
| if dim is None: |
| dim = list(range(1, x.ndim)) |
| norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) |
| norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) |
| return x / norm.to(x.dtype) |
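|
| # Usage sketch (illustrative only): `normalize` rescales so that the root-mean-square magnitude |
| # of the elements along `dim` is ~1 (plus `eps`). For example: |
| #   emb = torch.randn(2, 8, 32, 32, 128) |
| #   emb = normalize(emb, dim=-1, eps=1e-6) |
| #   # emb.pow(2).mean(dim=-1) is now close to 1 at every position |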
|
|
|
|
| class LearnablePosEmbAxis(VideoPositionEmb): |
| def __init__( |
| self, |
| *, |
| interpolation: str, |
| model_channels: int, |
| len_h: int, |
| len_w: int, |
| len_t: int, |
| device=None, |
| dtype=None, |
| **kwargs, |
| ): |
| """ |
| Args: |
| interpolation (str): currently only "crop" is supported. When extrapolation capacity is needed, frequency adjustment or other more advanced methods should be used; they are not implemented yet. |
| """ |
| del kwargs |
| super().__init__() |
| self.interpolation = interpolation |
| assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" |
|
|
| self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype)) |
| self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype)) |
| self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype)) |
|
|
| def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor: |
| B, T, H, W, _ = B_T_H_W_C |
| if self.interpolation == "crop": |
| emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype) |
| emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype) |
| emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype) |
| emb = ( |
| repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) |
| + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) |
| + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) |
| ) |
| assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" |
| else: |
| raise ValueError(f"Unknown interpolation method {self.interpolation}") |
|
|
| return normalize(emb, dim=-1, eps=1e-6) |
|
|
|
|
| class VideoRopePosition3DEmb(VideoPositionEmb): |
| def __init__( |
| self, |
| *, |
| head_dim: int, |
| len_h: int, |
| len_w: int, |
| len_t: int, |
| base_fps: int = 24, |
| h_extrapolation_ratio: float = 1.0, |
| w_extrapolation_ratio: float = 1.0, |
| t_extrapolation_ratio: float = 1.0, |
| enable_fps_modulation: bool = True, |
| device=None, |
| **kwargs, |
| ): |
| del kwargs |
| super().__init__() |
| self.base_fps = base_fps |
| self.max_h = len_h |
| self.max_w = len_w |
| self.enable_fps_modulation = enable_fps_modulation |
|
|
| dim = head_dim |
| dim_h = dim // 6 * 2 |
| dim_w = dim_h |
| dim_t = dim - 2 * dim_h |
| assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" |
| self.register_buffer( |
| "dim_spatial_range", |
| torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h, |
| persistent=False, |
| ) |
| self.register_buffer( |
| "dim_temporal_range", |
| torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t, |
| persistent=False, |
| ) |
|
|
| self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) |
| self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) |
| self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) |
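|
| # Note (an assumption about intent, not stated in this file): the exponent dim / (dim - 2) |
| # matches the NTK-aware RoPE scaling rule, enlarging the base theta just enough that the |
| # lowest-frequency component covers roughly `extrapolation_ratio` times the original range. |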
|
|
| def generate_embeddings( |
| self, |
| B_T_H_W_C: torch.Size, |
| fps: Optional[torch.Tensor] = None, |
| h_ntk_factor: Optional[float] = None, |
| w_ntk_factor: Optional[float] = None, |
| t_ntk_factor: Optional[float] = None, |
| device=None, |
| dtype=None, |
| ): |
| """ |
| Generate embeddings for the given input size. |
| |
| Args: |
| B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). |
| fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. |
| h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. |
| w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. |
| t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. |
| |
| Returns: |
| torch.Tensor: Rotary embedding tensor of shape (T * H * W, head_dim // 2, 2, 2), i.e. one 2x2 rotation matrix per position and frequency. |
| """ |
| h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor |
| w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor |
| t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor |
|
|
| h_theta = 10000.0 * h_ntk_factor |
| w_theta = 10000.0 * w_ntk_factor |
| t_theta = 10000.0 * t_ntk_factor |
|
|
| h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device)) |
| w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device)) |
| temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device)) |
|
|
| B, T, H, W, _ = B_T_H_W_C |
| seq = torch.arange(max(H, W, T), dtype=torch.float, device=device) |
| uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max()) |
| assert ( |
| uniform_fps or B == 1 or T == 1 |
| ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" |
| half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs) |
| half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs) |
|
|
| |
| if fps is None or self.enable_fps_modulation is False: |
| half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs) |
| else: |
| half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) |
|
|
| half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1) |
| half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1) |
| half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1) |
|
|
| em_T_H_W_D = torch.cat( |
| [ |
| repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W), |
| repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W), |
| repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H), |
| ] |
| , dim=-2, |
| ) |
|
|
| return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float() |
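|
| # Worked shape example (illustrative): for head_dim=128, dim_h = dim_w = 128 // 6 * 2 = 42 and |
| # dim_t = 128 - 2 * 42 = 44, so the concatenation holds 22 + 21 + 21 = 64 = head_dim // 2 |
| # frequencies and the returned tensor has shape (T * H * W, 64, 2, 2): one 2x2 rotation matrix |
| # per position and frequency, consumed by apply_rotary_pos_emb below. |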
|
|
|
|
| def apply_rotary_pos_emb( |
| t: torch.Tensor, |
| freqs: torch.Tensor, |
| ) -> torch.Tensor: |
| t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() |
| t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] |
| t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) |
| return t_out |
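|
| # Shape sketch (illustrative): in this file `freqs` is the RoPE tensor from VideoRopePosition3DEmb, |
| # broadcast to (1, S, 1, head_dim // 2, 2, 2) in MiniTrainDIT.forward, and `t` is a |
| # (B, S, n_heads, head_dim) query or key; each channel pair is rotated by its 2x2 matrix and the |
| # output keeps the shape and dtype of `t`. |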
|
|
|
|
| |
| class GPT2FeedForward(nn.Module): |
| def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None: |
| super().__init__() |
| self.activation = nn.GELU() |
| self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype) |
| self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype) |
|
|
| self._layer_id = None |
| self._dim = d_model |
| self._hidden_dim = d_ff |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = self.layer1(x) |
|
|
| x = self.activation(x) |
| x = self.layer2(x) |
| return x |
|
|
|
|
| def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: |
| """Computes multi-head attention using PyTorch's native implementation. |
| |
| This function provides a PyTorch backend alternative to Transformer Engine's attention operation. |
| It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product |
| attention, and rearranges the output back to the original format. |
| |
| The input tensor names use the following dimension conventions: |
| |
| - B: batch size |
| - S: sequence length |
| - H: number of attention heads |
| - D: head dimension |
| |
| Args: |
| q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim) |
| k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim) |
| v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim) |
| |
| Returns: |
| Attention output tensor with shape (batch, seq_len, n_heads * head_dim) |
| """ |
| in_q_shape = q_B_S_H_D.shape |
| in_k_shape = k_B_S_H_D.shape |
| q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) |
| k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) |
| v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) |
| return attention_forward(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, out_pattern="b s (n d)") |
|
|
|
|
| class Attention(nn.Module): |
| """ |
| A flexible attention module supporting both self-attention and cross-attention mechanisms. |
| |
| This module implements a multi-head attention layer that can operate in either self-attention |
| or cross-attention mode. The mode is determined by whether a context dimension is provided. |
| The implementation uses scaled dot-product attention with RMS-normalized queries and keys, and supports |
| dropout regularization on the output projection. |
| |
| Args: |
| query_dim (int): The dimensionality of the query vectors. |
| context_dim (int, optional): The dimensionality of the context (key/value) vectors. |
| If None, the module operates in self-attention mode using query_dim. Default: None |
| n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8 |
| head_dim (int, optional): The dimension of each attention head. Default: 64 |
| dropout (float, optional): Dropout probability applied to the output. Default: 0.0 |
| operations: Namespace providing the Linear and RMSNorm layer classes (e.g. torch.nn) used to build the projection layers. |
| |
| Examples: |
| >>> # Self-attention with 512 dimensions and 8 heads |
| >>> self_attn = Attention(query_dim=512, operations=torch.nn) |
| >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim) |
| >>> out = self_attn(x) # (32, 16, 512) |
| |
| >>> # Cross-attention |
| >>> cross_attn = Attention(query_dim=512, context_dim=256, operations=torch.nn) |
| >>> query = torch.randn(32, 16, 512) |
| >>> context = torch.randn(32, 8, 256) |
| >>> out = cross_attn(query, context) # (32, 16, 512) |
| """ |
|
|
| def __init__( |
| self, |
| query_dim: int, |
| context_dim: Optional[int] = None, |
| n_heads: int = 8, |
| head_dim: int = 64, |
| dropout: float = 0.0, |
| device=None, |
| dtype=None, |
| operations=None, |
| ) -> None: |
| super().__init__() |
| logging.debug( |
| f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " |
| f"{n_heads} heads with a dimension of {head_dim}." |
| ) |
| self.is_selfattn = context_dim is None |
|
|
| context_dim = query_dim if context_dim is None else context_dim |
| inner_dim = head_dim * n_heads |
|
|
| self.n_heads = n_heads |
| self.head_dim = head_dim |
| self.query_dim = query_dim |
| self.context_dim = context_dim |
|
|
| self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) |
| self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) |
|
|
| self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) |
| self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) |
|
|
| self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) |
| self.v_norm = nn.Identity() |
|
|
| self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) |
| self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() |
|
|
| self.attn_op = torch_attention_op |
|
|
| self._query_dim = query_dim |
| self._context_dim = context_dim |
| self._inner_dim = inner_dim |
|
|
| def compute_qkv( |
| self, |
| x: torch.Tensor, |
| context: Optional[torch.Tensor] = None, |
| rope_emb: Optional[torch.Tensor] = None, |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| q = self.q_proj(x) |
| context = x if context is None else context |
| k = self.k_proj(context) |
| v = self.v_proj(context) |
| q, k, v = map( |
| lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), |
| (q, k, v), |
| ) |
|
|
| def apply_norm_and_rotary_pos_emb( |
| q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor] |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| q = self.q_norm(q) |
| k = self.k_norm(k) |
| v = self.v_norm(v) |
| if self.is_selfattn and rope_emb is not None: |
| q = apply_rotary_pos_emb(q, rope_emb) |
| k = apply_rotary_pos_emb(k, rope_emb) |
| return q, k, v |
|
|
| q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) |
|
|
| return q, k, v |
|
|
| def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor: |
| result = self.attn_op(q, k, v, transformer_options=transformer_options) |
| return self.output_dropout(self.output_proj(result)) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| context: Optional[torch.Tensor] = None, |
| rope_emb: Optional[torch.Tensor] = None, |
| transformer_options: Optional[dict] = {}, |
| ) -> torch.Tensor: |
| """ |
| Args: |
| x (Tensor): The query tensor of shape [B, Mq, K] |
| context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None |
| """ |
| q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) |
| return self.compute_attention(q, k, v, transformer_options=transformer_options) |
|
|
|
|
| class Timesteps(nn.Module): |
| def __init__(self, num_channels: int): |
| super().__init__() |
| self.num_channels = num_channels |
|
|
| def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor: |
| assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}" |
| timesteps = timesteps_B_T.flatten().float() |
| half_dim = self.num_channels // 2 |
| exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) |
| exponent = exponent / (half_dim - 0.0) |
|
|
| emb = torch.exp(exponent) |
| emb = timesteps[:, None].float() * emb[None, :] |
|
|
| sin_emb = torch.sin(emb) |
| cos_emb = torch.cos(emb) |
| emb = torch.cat([cos_emb, sin_emb], dim=-1) |
|
|
| return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1]) |
|
|
|
|
| class TimestepEmbedding(nn.Module): |
| def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None): |
| super().__init__() |
| logging.debug( |
| f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." |
| ) |
| self.in_dim = in_features |
| self.out_dim = out_features |
| self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype) |
| self.activation = nn.SiLU() |
| self.use_adaln_lora = use_adaln_lora |
| if use_adaln_lora: |
| self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype) |
| else: |
| self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype) |
|
|
| def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
| emb = self.linear_1(sample) |
| emb = self.activation(emb) |
| emb = self.linear_2(emb) |
|
|
| if self.use_adaln_lora: |
| adaln_lora_B_T_3D = emb |
| emb_B_T_D = sample |
| else: |
| adaln_lora_B_T_3D = None |
| emb_B_T_D = emb |
|
|
| return emb_B_T_D, adaln_lora_B_T_3D |
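|
| # Note on the return convention above: with use_adaln_lora=True the MLP output becomes the |
| # per-block AdaLN-LoRA term (3 * out_features wide) and the raw sinusoidal `sample` is passed |
| # through as the timestep embedding; MiniTrainDIT.forward then applies t_embedding_norm to it. |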
|
|
|
|
| class PatchEmbed(nn.Module): |
| """ |
| PatchEmbed is a module for embedding patches from an input tensor by rearranging non-overlapping |
| spatio-temporal patches and projecting each patch with a linear layer. This module can process inputs |
| with temporal (video) and spatial (image) dimensions, making it suitable for video and image processing |
| tasks. It divides the input into patches and embeds each patch into a vector of size `out_channels`. |
| |
| Parameters: |
| - spatial_patch_size (int): The size of each spatial patch. |
| - temporal_patch_size (int): The size of each temporal patch. |
| - in_channels (int): Number of input channels. Default: 3. |
| - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. |
| - device, dtype, operations: Passed through to the patch projection layer (the projection has no bias). |
| """ |
|
|
| def __init__( |
| self, |
| spatial_patch_size: int, |
| temporal_patch_size: int, |
| in_channels: int = 3, |
| out_channels: int = 768, |
| device=None, dtype=None, operations=None |
| ): |
| super().__init__() |
| self.spatial_patch_size = spatial_patch_size |
| self.temporal_patch_size = temporal_patch_size |
|
|
| self.proj = nn.Sequential( |
| Rearrange( |
| "b c (t r) (h m) (w n) -> b t h w (c r m n)", |
| r=temporal_patch_size, |
| m=spatial_patch_size, |
| n=spatial_patch_size, |
| ), |
| operations.Linear( |
| in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype |
| ), |
| ) |
| self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| Forward pass of the PatchEmbed module. |
| |
| Parameters: |
| - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where |
| B is the batch size, |
| C is the number of channels, |
| T is the temporal dimension, |
| H is the height, and |
| W is the width of the input. |
| |
| Returns: |
| - torch.Tensor: The embedded patches, with shape (B, T / temporal_patch_size, H / spatial_patch_size, W / spatial_patch_size, out_channels). |
| """ |
| assert x.dim() == 5 |
| _, _, T, H, W = x.shape |
| assert ( |
| H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 |
| ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}" |
| assert T % self.temporal_patch_size == 0 |
| x = self.proj(x) |
| return x |
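|
| # Shape sketch (illustrative): with spatial_patch_size=2, temporal_patch_size=1 and |
| # out_channels=768, an input of shape (B, C, T, 64, 64) is mapped to (B, T, 32, 32, 768). |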
|
|
|
|
| class FinalLayer(nn.Module): |
| """ |
| The final layer of video DiT. |
| """ |
|
|
| def __init__( |
| self, |
| hidden_size: int, |
| spatial_patch_size: int, |
| temporal_patch_size: int, |
| out_channels: int, |
| use_adaln_lora: bool = False, |
| adaln_lora_dim: int = 256, |
| device=None, dtype=None, operations=None |
| ): |
| super().__init__() |
| self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
| self.linear = operations.Linear( |
| hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype |
| ) |
| self.hidden_size = hidden_size |
| self.n_adaln_chunks = 2 |
| self.use_adaln_lora = use_adaln_lora |
| self.adaln_lora_dim = adaln_lora_dim |
| if use_adaln_lora: |
| self.adaln_modulation = nn.Sequential( |
| nn.SiLU(), |
| operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype), |
| operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype), |
| ) |
| else: |
| self.adaln_modulation = nn.Sequential( |
| nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype) |
| ) |
|
|
| def forward( |
| self, |
| x_B_T_H_W_D: torch.Tensor, |
| emb_B_T_D: torch.Tensor, |
| adaln_lora_B_T_3D: Optional[torch.Tensor] = None, |
| ): |
| if self.use_adaln_lora: |
| assert adaln_lora_B_T_3D is not None |
| shift_B_T_D, scale_B_T_D = ( |
| self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] |
| ).chunk(2, dim=-1) |
| else: |
| shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) |
|
|
| shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange( |
| scale_B_T_D, "b t d -> b t 1 1 d" |
| ) |
|
|
| def _fn( |
| _x_B_T_H_W_D: torch.Tensor, |
| _norm_layer: nn.Module, |
| _scale_B_T_1_1_D: torch.Tensor, |
| _shift_B_T_1_1_D: torch.Tensor, |
| ) -> torch.Tensor: |
| return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D |
|
|
| x_B_T_H_W_D = _fn(x_B_T_H_W_D, self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D) |
| x_B_T_H_W_O = self.linear(x_B_T_H_W_D) |
| return x_B_T_H_W_O |
|
|
|
|
| class Block(nn.Module): |
| """ |
| A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation. |
| Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation. |
| |
| Parameters: |
| x_dim (int): Dimension of input features |
| context_dim (int): Dimension of context features for cross-attention |
| num_heads (int): Number of attention heads |
| mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0 |
| use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False |
| adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256 |
| |
| The block applies the following sequence: |
| 1. Self-attention with AdaLN modulation |
| 2. Cross-attention with AdaLN modulation |
| 3. MLP with AdaLN modulation |
| |
| Each component uses skip connections and layer normalization. |
| """ |
|
|
| def __init__( |
| self, |
| x_dim: int, |
| context_dim: int, |
| num_heads: int, |
| mlp_ratio: float = 4.0, |
| use_adaln_lora: bool = False, |
| adaln_lora_dim: int = 256, |
| device=None, |
| dtype=None, |
| operations=None, |
| ): |
| super().__init__() |
| self.x_dim = x_dim |
| self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) |
| self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations) |
|
|
| self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) |
| self.cross_attn = Attention( |
| x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations |
| ) |
|
|
| self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) |
| self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations) |
|
|
| self.use_adaln_lora = use_adaln_lora |
| if self.use_adaln_lora: |
| self.adaln_modulation_self_attn = nn.Sequential( |
| nn.SiLU(), |
| operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), |
| operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), |
| ) |
| self.adaln_modulation_cross_attn = nn.Sequential( |
| nn.SiLU(), |
| operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), |
| operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), |
| ) |
| self.adaln_modulation_mlp = nn.Sequential( |
| nn.SiLU(), |
| operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), |
| operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), |
| ) |
| else: |
| self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) |
| self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) |
| self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) |
|
|
| def forward( |
| self, |
| x_B_T_H_W_D: torch.Tensor, |
| emb_B_T_D: torch.Tensor, |
| crossattn_emb: torch.Tensor, |
| rope_emb_L_1_1_D: Optional[torch.Tensor] = None, |
| adaln_lora_B_T_3D: Optional[torch.Tensor] = None, |
| extra_per_block_pos_emb: Optional[torch.Tensor] = None, |
| transformer_options: Optional[dict] = {}, |
| ) -> torch.Tensor: |
| residual_dtype = x_B_T_H_W_D.dtype |
| compute_dtype = emb_B_T_D.dtype |
| if extra_per_block_pos_emb is not None: |
| x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb |
|
|
| if self.use_adaln_lora: |
| shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( |
| self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D |
| ).chunk(3, dim=-1) |
| shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( |
| self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D |
| ).chunk(3, dim=-1) |
| shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( |
| self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D |
| ).chunk(3, dim=-1) |
| else: |
| shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( |
| emb_B_T_D |
| ).chunk(3, dim=-1) |
| shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( |
| emb_B_T_D |
| ).chunk(3, dim=-1) |
| shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) |
|
|
| |
| shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") |
| scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") |
| gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") |
|
|
| shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") |
| scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") |
| gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") |
|
|
| shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") |
| scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") |
| gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") |
|
|
| B, T, H, W, D = x_B_T_H_W_D.shape |
|
|
| def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): |
| return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D |
|
|
| normalized_x_B_T_H_W_D = _fn( |
| x_B_T_H_W_D, |
| self.layer_norm_self_attn, |
| scale_self_attn_B_T_1_1_D, |
| shift_self_attn_B_T_1_1_D, |
| ) |
| result_B_T_H_W_D = rearrange( |
| self.self_attn( |
| |
| rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), |
| None, |
| rope_emb=rope_emb_L_1_1_D, |
| transformer_options=transformer_options, |
| ), |
| "b (t h w) d -> b t h w d", |
| t=T, |
| h=H, |
| w=W, |
| ) |
| x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) |
|
|
| def _x_fn( |
| _x_B_T_H_W_D: torch.Tensor, |
| layer_norm_cross_attn: Callable, |
| _scale_cross_attn_B_T_1_1_D: torch.Tensor, |
| _shift_cross_attn_B_T_1_1_D: torch.Tensor, |
| transformer_options: Optional[dict] = {}, |
| ) -> torch.Tensor: |
| _normalized_x_B_T_H_W_D = _fn( |
| _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D |
| ) |
| _result_B_T_H_W_D = rearrange( |
| self.cross_attn( |
| rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), |
| crossattn_emb, |
| rope_emb=rope_emb_L_1_1_D, |
| transformer_options=transformer_options, |
| ), |
| "b (t h w) d -> b t h w d", |
| t=T, |
| h=H, |
| w=W, |
| ) |
| return _result_B_T_H_W_D |
|
|
| result_B_T_H_W_D = _x_fn( |
| x_B_T_H_W_D, |
| self.layer_norm_cross_attn, |
| scale_cross_attn_B_T_1_1_D, |
| shift_cross_attn_B_T_1_1_D, |
| transformer_options=transformer_options, |
| ) |
| x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D |
|
|
| normalized_x_B_T_H_W_D = _fn( |
| x_B_T_H_W_D, |
| self.layer_norm_mlp, |
| scale_mlp_B_T_1_1_D, |
| shift_mlp_B_T_1_1_D, |
| ) |
| result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) |
| x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) |
| return x_B_T_H_W_D |
|
|
|
|
| class MiniTrainDIT(nn.Module): |
| """ |
| A clean implementation of DiT that can load and reproduce the training results of the original DiT model in Cosmos 1. |
| A general implementation of an AdaLN-modulated, ViT-like (DiT) transformer for video processing. |
| |
| Args: |
| max_img_h (int): Maximum height of the input images. |
| max_img_w (int): Maximum width of the input images. |
| max_frames (int): Maximum number of frames in the video sequence. |
| in_channels (int): Number of input channels (e.g., RGB channels for color images). |
| out_channels (int): Number of output channels. |
| patch_spatial (int): Spatial patch size used when patchifying the input. |
| patch_temporal (int): Temporal resolution of patches for input processing. |
| concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. |
| model_channels (int): Base number of channels used throughout the model. |
| num_blocks (int): Number of transformer blocks. |
| num_heads (int): Number of heads in the multi-head attention layers. |
| mlp_ratio (float): Expansion ratio for MLP blocks. |
| crossattn_emb_channels (int): Number of embedding channels for cross-attention. |
| pos_emb_cls (str): Type of positional embeddings. |
| pos_emb_learnable (bool): Whether positional embeddings are learnable. |
| pos_emb_interpolation (str): Method for interpolating positional embeddings. |
| min_fps (int): Minimum frames per second. |
| max_fps (int): Maximum frames per second. |
| use_adaln_lora (bool): Whether to use AdaLN-LoRA. |
| adaln_lora_dim (int): Dimension for AdaLN-LoRA. |
| rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. |
| rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. |
| rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. |
| extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. |
| extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. |
| extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. |
| extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. |
| """ |
|
|
| def __init__( |
| self, |
| max_img_h: int, |
| max_img_w: int, |
| max_frames: int, |
| in_channels: int, |
| out_channels: int, |
| patch_spatial: int, |
| patch_temporal: int, |
| concat_padding_mask: bool = True, |
| |
| model_channels: int = 768, |
| num_blocks: int = 10, |
| num_heads: int = 16, |
| mlp_ratio: float = 4.0, |
| |
| crossattn_emb_channels: int = 1024, |
| |
| pos_emb_cls: str = "sincos", |
| pos_emb_learnable: bool = False, |
| pos_emb_interpolation: str = "crop", |
| min_fps: int = 1, |
| max_fps: int = 30, |
| use_adaln_lora: bool = False, |
| adaln_lora_dim: int = 256, |
| rope_h_extrapolation_ratio: float = 1.0, |
| rope_w_extrapolation_ratio: float = 1.0, |
| rope_t_extrapolation_ratio: float = 1.0, |
| extra_per_block_abs_pos_emb: bool = False, |
| extra_h_extrapolation_ratio: float = 1.0, |
| extra_w_extrapolation_ratio: float = 1.0, |
| extra_t_extrapolation_ratio: float = 1.0, |
| rope_enable_fps_modulation: bool = True, |
| image_model=None, |
| device=None, |
| dtype=None, |
| operations=None, |
| ) -> None: |
| super().__init__() |
| self.dtype = dtype |
| self.max_img_h = max_img_h |
| self.max_img_w = max_img_w |
| self.max_frames = max_frames |
| self.in_channels = in_channels |
| self.out_channels = out_channels |
| self.patch_spatial = patch_spatial |
| self.patch_temporal = patch_temporal |
| self.num_heads = num_heads |
| self.num_blocks = num_blocks |
| self.model_channels = model_channels |
| self.concat_padding_mask = concat_padding_mask |
| |
| self.pos_emb_cls = pos_emb_cls |
| self.pos_emb_learnable = pos_emb_learnable |
| self.pos_emb_interpolation = pos_emb_interpolation |
| self.min_fps = min_fps |
| self.max_fps = max_fps |
| self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio |
| self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio |
| self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio |
| self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb |
| self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio |
| self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio |
| self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio |
| self.rope_enable_fps_modulation = rope_enable_fps_modulation |
|
|
| self.build_pos_embed(device=device, dtype=dtype) |
| self.use_adaln_lora = use_adaln_lora |
| self.adaln_lora_dim = adaln_lora_dim |
| self.t_embedder = nn.Sequential( |
| Timesteps(model_channels), |
| TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,), |
| ) |
|
|
| in_channels = in_channels + 1 if concat_padding_mask else in_channels |
| self.x_embedder = PatchEmbed( |
| spatial_patch_size=patch_spatial, |
| temporal_patch_size=patch_temporal, |
| in_channels=in_channels, |
| out_channels=model_channels, |
| device=device, dtype=dtype, operations=operations, |
| ) |
|
|
| self.blocks = nn.ModuleList( |
| [ |
| Block( |
| x_dim=model_channels, |
| context_dim=crossattn_emb_channels, |
| num_heads=num_heads, |
| mlp_ratio=mlp_ratio, |
| use_adaln_lora=use_adaln_lora, |
| adaln_lora_dim=adaln_lora_dim, |
| device=device, dtype=dtype, operations=operations, |
| ) |
| for _ in range(num_blocks) |
| ] |
| ) |
|
|
| self.final_layer = FinalLayer( |
| hidden_size=self.model_channels, |
| spatial_patch_size=self.patch_spatial, |
| temporal_patch_size=self.patch_temporal, |
| out_channels=self.out_channels, |
| use_adaln_lora=self.use_adaln_lora, |
| adaln_lora_dim=self.adaln_lora_dim, |
| device=device, dtype=dtype, operations=operations, |
| ) |
|
|
| self.t_embedding_norm = operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype) |
|
|
| def build_pos_embed(self, device=None, dtype=None) -> None: |
| if self.pos_emb_cls == "rope3d": |
| cls_type = VideoRopePosition3DEmb |
| else: |
| raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") |
|
|
| logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") |
| kwargs = dict( |
| model_channels=self.model_channels, |
| len_h=self.max_img_h // self.patch_spatial, |
| len_w=self.max_img_w // self.patch_spatial, |
| len_t=self.max_frames // self.patch_temporal, |
| max_fps=self.max_fps, |
| min_fps=self.min_fps, |
| is_learnable=self.pos_emb_learnable, |
| interpolation=self.pos_emb_interpolation, |
| head_dim=self.model_channels // self.num_heads, |
| h_extrapolation_ratio=self.rope_h_extrapolation_ratio, |
| w_extrapolation_ratio=self.rope_w_extrapolation_ratio, |
| t_extrapolation_ratio=self.rope_t_extrapolation_ratio, |
| enable_fps_modulation=self.rope_enable_fps_modulation, |
| device=device, |
| ) |
| self.pos_embedder = cls_type( |
| **kwargs, |
| ) |
|
|
| if self.extra_per_block_abs_pos_emb: |
| kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio |
| kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio |
| kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio |
| kwargs["device"] = device |
| kwargs["dtype"] = dtype |
| self.extra_pos_embedder = LearnablePosEmbAxis( |
| **kwargs, |
| ) |
|
|
| def prepare_embedded_sequence( |
| self, |
| x_B_C_T_H_W: torch.Tensor, |
| fps: Optional[torch.Tensor] = None, |
| padding_mask: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: |
| """ |
| Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. |
| |
| Args: |
| x_B_C_T_H_W (torch.Tensor): Input video tensor of shape (B, C, T, H, W). |
| fps (Optional[torch.Tensor]): Frames per second tensor used for fps-aware positional embedding when required. |
| If None, frame indices are used without fps modulation. |
| padding_mask (Optional[torch.Tensor]): Spatial padding mask; if None and `concat_padding_mask` is True, an all-zero mask is created. |
| |
| Returns: |
| Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: |
| - A tensor of shape (B, T, H, W, D) with the embedded sequence. |
| - An optional rotary positional embedding tensor, returned only if the positional embedding class |
| (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. |
| - An optional extra per-block absolute positional embedding, or None when it is disabled. |
| |
| Notes: |
| - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. |
| - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. |
| - If 'rope' is in `self.pos_emb_cls` (case insensitive), the rotary embeddings are generated by |
| `self.pos_embedder` over the [T, H, W] grid (using the fps tensor when provided) and returned |
| separately rather than being added to the sequence. |
| - Otherwise, the output of `self.pos_embedder` is added to the sequence without considering fps. |
| """ |
| if self.concat_padding_mask: |
| if padding_mask is None: |
| padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device) |
| else: |
| padding_mask = transforms.functional.resize( |
| padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST |
| ) |
| x_B_C_T_H_W = torch.cat( |
| [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 |
| ) |
| x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) |
|
|
| if self.extra_per_block_abs_pos_emb: |
| extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype) |
| else: |
| extra_pos_emb = None |
|
|
| if "rope" in self.pos_emb_cls.lower(): |
| return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb |
| x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) |
|
|
| return x_B_T_H_W_D, None, extra_pos_emb |
|
|
| def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: |
| x_B_C_Tt_Hp_Wp = rearrange( |
| x_B_T_H_W_M, |
| "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", |
| p1=self.patch_spatial, |
| p2=self.patch_spatial, |
| t=self.patch_temporal, |
| ) |
| return x_B_C_Tt_Hp_Wp |
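|
| # Shape sketch (illustrative): with patch_spatial=2, patch_temporal=1 and out_channels=16, a |
| # final-layer output of shape (B, T, 32, 32, 2 * 2 * 1 * 16) is rearranged back to |
| # (B, 16, T, 64, 64), inverting PatchEmbed up to the learned projection. |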
| |
| def pad_to_patch_size(self, img, patch_size=(2, 2), padding_mode="circular"): |
| if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()): |
| padding_mode = "reflect" |
|
|
| pad = () |
| for i in range(img.ndim - 2): |
| pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad |
|
|
| return torch.nn.functional.pad(img, pad, mode=padding_mode) |
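|
| # Usage sketch (illustrative): pad_to_patch_size(x, (1, 2, 2)) right-pads the trailing T, H, W |
| # dimensions of a (B, C, T, H, W) tensor up to multiples of 1, 2 and 2 respectively, matching |
| # the (patch_temporal, patch_spatial, patch_spatial) tuple passed in forward below. |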
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| timesteps: torch.Tensor, |
| context: torch.Tensor, |
| fps: Optional[torch.Tensor] = None, |
| padding_mask: Optional[torch.Tensor] = None, |
| use_gradient_checkpointing=False, |
| use_gradient_checkpointing_offload=False, |
| **kwargs, |
| ): |
| orig_shape = list(x.shape) |
| x = self.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) |
| x_B_C_T_H_W = x |
| timesteps_B_T = timesteps |
| crossattn_emb = context |
| """ |
| Args: |
| x: (B, C, T, H, W) tensor of spatial-temp inputs |
| timesteps: (B, ) tensor of timesteps |
| crossattn_emb: (B, N, D) tensor of cross-attention embeddings |
| """ |
| x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( |
| x_B_C_T_H_W, |
| fps=fps, |
| padding_mask=padding_mask, |
| ) |
|
|
| if timesteps_B_T.ndim == 1: |
| timesteps_B_T = timesteps_B_T.unsqueeze(1) |
| t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype)) |
| t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) |
|
|
| |
| affline_scale_log_info = {} |
| affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() |
| self.affline_scale_log_info = affline_scale_log_info |
| self.affline_emb = t_embedding_B_T_D |
| self.crossattn_emb = crossattn_emb |
|
|
| if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: |
| assert ( |
| x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape |
| ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}" |
|
|
| block_kwargs = { |
| "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), |
| "adaln_lora_B_T_3D": adaln_lora_B_T_3D, |
| "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, |
| "transformer_options": kwargs.get("transformer_options", {}), |
| } |
|
|
| |
| |
| |
| |
| if x_B_T_H_W_D.dtype == torch.float16: |
| x_B_T_H_W_D = x_B_T_H_W_D.float() |
|
|
| for block in self.blocks: |
| x_B_T_H_W_D = gradient_checkpoint_forward( |
| block, |
| use_gradient_checkpointing=use_gradient_checkpointing, |
| use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, |
| x_B_T_H_W_D=x_B_T_H_W_D, |
| emb_B_T_D=t_embedding_B_T_D, |
| crossattn_emb=crossattn_emb, |
| **block_kwargs, |
| ) |
|
|
| x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) |
| x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]] |
| return x_B_C_Tt_Hp_Wp |
|
|
|
|
| def rotate_half(x): |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
| def apply_rotary_pos_emb2(x, cos, sin, unsqueeze_dim=1): |
| cos = cos.unsqueeze(unsqueeze_dim) |
| sin = sin.unsqueeze(unsqueeze_dim) |
| x_embed = (x * cos) + (rotate_half(x) * sin) |
| return x_embed |
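|
| # Note: rotate_half / apply_rotary_pos_emb2 implement the split-halves RoPE variant used by the |
| # LLMAdapter below: channel i is paired with channel i + d/2 and rotated by the cos/sin tables |
| # produced by RotaryEmbedding, broadcast over the head dimension via `unsqueeze_dim`. |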
|
|
|
|
| class RotaryEmbedding(nn.Module): |
| def __init__(self, head_dim): |
| super().__init__() |
| self.rope_theta = 10000 |
| inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
| @torch.no_grad() |
| def forward(self, x, position_ids): |
| inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) |
| position_ids_expanded = position_ids[:, None, :].float() |
|
|
| device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| with torch.autocast(device_type=device_type, enabled=False): |
| freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| cos = emb.cos() |
| sin = emb.sin() |
|
|
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
|
| class LLMAdapterAttention(nn.Module): |
| def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None): |
| super().__init__() |
|
|
| inner_dim = head_dim * n_heads |
| self.n_heads = n_heads |
| self.head_dim = head_dim |
| self.query_dim = query_dim |
| self.context_dim = context_dim |
|
|
| self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) |
| self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) |
|
|
| self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) |
| self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) |
|
|
| self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) |
|
|
| self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) |
|
|
| def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): |
| context = x if context is None else context |
| input_shape = x.shape[:-1] |
| q_shape = (*input_shape, self.n_heads, self.head_dim) |
| context_shape = context.shape[:-1] |
| kv_shape = (*context_shape, self.n_heads, self.head_dim) |
|
|
| query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) |
| key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) |
| value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) |
|
|
| if position_embeddings is not None: |
| assert position_embeddings_context is not None |
| cos, sin = position_embeddings |
| query_states = apply_rotary_pos_emb2(query_states, cos, sin) |
| cos, sin = position_embeddings_context |
| key_states = apply_rotary_pos_emb2(key_states, cos, sin) |
|
|
| attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) |
|
|
| attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() |
| attn_output = self.o_proj(attn_output) |
| return attn_output |
|
|
| def init_weights(self): |
| torch.nn.init.zeros_(self.o_proj.weight) |
|
|
|
|
| class LLMAdapterTransformerBlock(nn.Module): |
| def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None): |
| super().__init__() |
| self.use_self_attn = use_self_attn |
|
|
| if self.use_self_attn: |
| self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) |
| self.self_attn = LLMAdapterAttention( |
| query_dim=model_dim, |
| context_dim=model_dim, |
| n_heads=num_heads, |
| head_dim=model_dim//num_heads, |
| device=device, |
| dtype=dtype, |
| operations=operations, |
| ) |
|
|
| self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) |
| self.cross_attn = LLMAdapterAttention( |
| query_dim=model_dim, |
| context_dim=source_dim, |
| n_heads=num_heads, |
| head_dim=model_dim//num_heads, |
| device=device, |
| dtype=dtype, |
| operations=operations, |
| ) |
|
|
| self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) |
| self.mlp = nn.Sequential( |
| operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype), |
| nn.GELU(), |
| operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype) |
| ) |
|
|
| def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None): |
| if self.use_self_attn: |
| normed = self.norm_self_attn(x) |
| attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings) |
| x = x + attn_out |
|
|
| normed = self.norm_cross_attn(x) |
| attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) |
| x = x + attn_out |
|
|
| x = x + self.mlp(self.norm_mlp(x)) |
| return x |
|
|
| def init_weights(self): |
| torch.nn.init.zeros_(self.mlp[2].weight) |
| self.cross_attn.init_weights() |
|
|
|
|
| class LLMAdapter(nn.Module): |
| def __init__( |
| self, |
| source_dim=1024, |
| target_dim=1024, |
| model_dim=1024, |
| num_layers=6, |
| num_heads=16, |
| use_self_attn=True, |
| layer_norm=False, |
| device=None, |
| dtype=None, |
| operations=None, |
| ): |
| super().__init__() |
|
|
| self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype) |
| if model_dim != target_dim: |
| self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype) |
| else: |
| self.in_proj = nn.Identity() |
| self.rotary_emb = RotaryEmbedding(model_dim//num_heads) |
| self.blocks = nn.ModuleList([ |
| LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) |
| ]) |
| self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype) |
| self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype) |
|
|
| def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): |
| if target_attention_mask is not None: |
| target_attention_mask = target_attention_mask.to(torch.bool) |
| if target_attention_mask.ndim == 2: |
| target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) |
|
|
| if source_attention_mask is not None: |
| source_attention_mask = source_attention_mask.to(torch.bool) |
| if source_attention_mask.ndim == 2: |
| source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) |
|
|
| context = source_hidden_states |
| x = self.in_proj(self.embed(target_input_ids).to(context.dtype)) |
| position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) |
| position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) |
| position_embeddings = self.rotary_emb(x, position_ids) |
| position_embeddings_context = self.rotary_emb(x, position_ids_context) |
| for block in self.blocks: |
| x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) |
| return self.norm(self.out_proj(x)) |
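|
| # Usage sketch (illustrative, shapes assumed for the example): given source hidden states of |
| # shape (B, S_src, source_dim) and target token ids of shape (B, S_tgt) over the 32128-entry |
| # (T5-style) vocabulary, the adapter returns refined embeddings of shape (B, S_tgt, target_dim), |
| # which AnimaDiT pads to 512 tokens and feeds to the DiT cross-attention as `context`. |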
|
|
|
|
| class AnimaDiT(MiniTrainDIT): |
|
|
| _repeated_blocks = ["Block"] |
|
|
| def __init__(self): |
| kwargs = { |
| 'image_model': 'anima', 'max_img_h': 240, 'max_img_w': 240, 'max_frames': 128, |
| 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, |
| 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, |
| 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', |
| 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, |
| 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, |
| 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, |
| 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, |
| 'rope_enable_fps_modulation': False, 'dtype': torch.bfloat16, 'device': None, 'operations': torch.nn, |
| } |
| super().__init__(**kwargs) |
| self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) |
|
|
| def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None): |
| if text_ids is not None: |
| out = self.llm_adapter(text_embeds, text_ids) |
| if t5xxl_weights is not None: |
| out = out * t5xxl_weights |
|
|
| if out.shape[1] < 512: |
| out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1])) |
| return out |
| else: |
| return text_embeds |
|
|
| def forward( |
| self, |
| x, timesteps, context, |
| use_gradient_checkpointing=False, |
| use_gradient_checkpointing_offload=False, |
| **kwargs |
| ): |
| t5xxl_ids = kwargs.pop("t5xxl_ids", None) |
| if t5xxl_ids is not None: |
| context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None)) |
| return super().forward( |
| x, timesteps, context, |
| use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, |
| **kwargs |
| ) |
|
|