| | import logging
|
| | import math
|
| | from typing import Dict, Optional
|
| |
|
| | import numpy as np
|
| | import torch
|
| | import torch.nn as nn
|
| | from .. import attention
|
| | from einops import rearrange, repeat
|
| | from .util import timestep_embedding
|
| | import comfy.ops
|
| | import comfy.ldm.common_dit
|
| |
|
def default(x, y):
    """Return ``x`` unless it is ``None``; in that case return the fallback ``y``."""
    return y if x is None else x
|
| |
|
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks

    Computes ``fc1 -> act -> dropout -> norm -> fc2 -> dropout``.

    Args:
        in_features: input feature size.
        hidden_features: width of the hidden layer (defaults to ``in_features``).
        out_features: output feature size (defaults to ``in_features``).
        act_layer: zero-argument callable returning the activation module.
        norm_layer: optional callable ``norm_layer(hidden_features)`` applied after the
            first dropout; ``nn.Identity`` when ``None``.
        bias: whether the two projections have a bias.
        drop: dropout probability used for both dropout layers.
        use_conv: use 1x1 ``operations.Conv2d`` projections instead of ``operations.Linear``.
        dtype, device: forwarded to the projection constructors.
        operations: namespace providing ``Linear`` / ``Conv2d`` classes.
    """
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        use_conv=False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = drop
        if use_conv:
            # Local import: `partial` is only needed for the 1x1-conv variant and was
            # previously referenced without being imported (NameError for use_conv=True).
            from functools import partial
            linear_layer = partial(operations.Conv2d, kernel_size=1)
        else:
            linear_layer = operations.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs)
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
        self.drop2 = nn.Dropout(drop_probs)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
|
| |
|
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Cuts the input into non-overlapping ``patch_size x patch_size`` tiles with a strided
    convolution and projects each tile to ``embed_dim`` channels. When ``flatten`` is set
    the output is reshaped from (B, embed_dim, H', W') to (B, H'*W', embed_dim).
    """
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
        self,
        img_size: Optional[int] = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer = None,
        flatten: bool = True,
        bias: bool = True,
        strict_img_size: bool = True,
        dynamic_img_pad: bool = True,
        padding_mode='circular',
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        self.padding_mode = padding_mode
        if img_size is None:
            # No fixed input resolution: grid bookkeeping is unknown.
            self.img_size = self.grid_size = self.num_patches = None
        else:
            self.img_size = (img_size, img_size)
            self.grid_size = tuple(side // step for side, step in zip(self.img_size, self.patch_size))
            self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        # The unpack doubles as a check that the input is 4-D.
        _batch, _chans, _height, _width = x.shape
        if self.dynamic_img_pad:
            # Pad spatial dims so the strided conv covers every pixel.
            x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
        tokens = self.proj(x)
        if self.flatten:
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)
|
| |
|
def modulate(x, shift, scale):
    """Return ``x * (1 + scale) + shift`` with ``scale``/``shift`` broadcast over dim 1.

    ``scale`` and ``shift`` get a new axis inserted at position 1 before broadcasting.
    A ``None`` shift is treated as all zeros.
    """
    scale = scale.unsqueeze(1)
    if shift is None:
        shift = torch.zeros_like(scale)
    else:
        shift = shift.unsqueeze(1)
    return x * (1 + scale) + shift
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    scaling_factor=None,
    offset=None,
):
    """
    Build a fixed 2-D sin/cos positional embedding for a square grid.

    grid_size: int side length of the square grid; coordinates are 0..grid_size-1,
        divided by ``scaling_factor`` and then shifted by ``-offset`` when given.
    return:
        pos_embed: numpy array [grid_size*grid_size, embed_dim], or
        [extra_tokens + grid_size*grid_size, embed_dim] with zero rows prepended
        when ``cls_token`` is true and ``extra_tokens > 0``.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    grid = np.stack(np.meshgrid(coords, coords), axis=0)
    if scaling_factor is not None:
        grid = grid / scaling_factor
    if offset is not None:
        grid = grid - offset

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid.reshape([2, 1, grid_size, grid_size]))
    if cls_token and extra_tokens > 0:
        prefix = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([prefix, pos_embed], axis=0)
    return pos_embed
|
| |
|
| |
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-row coordinate grid into ``embed_dim`` channels.

    The first half of the channels is the 1-D embedding of ``grid[0]``, the second half
    that of ``grid[1]``; ``embed_dim`` must be even.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)
|
| |
|
| |
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: array of positions; flattened to size (M,)
    out: (M, D) array whose first D/2 columns are sin(pos * f) and last D/2 are
        cos(pos * f), with frequencies f = 1 / 10000**(i / (D/2)).
    """
    assert embed_dim % 2 == 0
    freqs = np.arange(embed_dim // 2, dtype=np.float64)
    freqs /= embed_dim / 2.0
    freqs = 1.0 / 10000**freqs

    # Outer product positions x frequencies -> (M, D/2)
    angles = pos.reshape(-1)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
| |
|
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
    """Torch version of the 1-D sin/cos embedding.

    ``pos`` is flattened to (M,); returns (M, 2 * (embed_dim // 2)) with sines in the
    first half of the columns and cosines in the second.
    """
    exponents = torch.arange(embed_dim // 2, device=device, dtype=dtype)
    exponents /= embed_dim / 2.0
    inv_freq = 1.0 / 10000**exponents
    angles = torch.einsum("m,d->md", pos.reshape(-1), inv_freq)
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
|
| |
|
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
    """Build an (h*w, embed_dim) sin/cos embedding on a grid centred at ``val_center``.

    The shorter side spans ``val_center +/- val_magnitude``; the longer side's extent is
    scaled by its aspect ratio. The first half of the channels encodes the w coordinate,
    the second half the h coordinate.
    """
    shortest = min(h, w)
    extent_h = (h / shortest) * val_magnitude
    extent_w = (w / shortest) * val_magnitude
    coords_h = torch.linspace(-extent_h + val_center, extent_h + val_center, h, device=device, dtype=dtype)
    coords_w = torch.linspace(-extent_w + val_center, extent_w + val_center, w, device=device, dtype=dtype)
    grid_h, grid_w = torch.meshgrid(coords_h, coords_w, indexing='ij')

    half_dim = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(half_dim, grid_h, device=device, dtype=dtype)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(half_dim, grid_w, device=device, dtype=dtype)
    return torch.cat([emb_w, emb_h], dim=1)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class TimestepEmbedder(nn.Module):
    """
    Turn scalar timesteps into ``hidden_size`` vectors.

    The timesteps are first expanded into ``frequency_embedding_size`` sinusoidal features
    (via ``timestep_embedding``). They are then passed through Linear -> SiLU -> Linear.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        layers = [
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t, dtype, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        features = timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features.to(dtype))
|
| |
|
| |
|
class VectorEmbedder(nn.Module):
    """
    Project a flat vector of size ``input_dim`` to ``hidden_size``.

    The projection is Linear -> SiLU -> Linear.
    """

    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None, operations=None):
        super().__init__()
        in_proj = operations.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device)
        out_proj = operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device)
        self.mlp = nn.Sequential(in_proj, nn.SiLU(), out_proj)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def split_qkv(qkv, head_dim):
    """Split a fused projection ``(B, L, 3 * H * head_dim)`` into q, k, v.

    Each of q, k and v has shape ``(B, L, H, head_dim)``. They are views into ``qkv``
    (no copy).
    """
    batch, length = qkv.shape[0], qkv.shape[1]
    q, k, v = qkv.reshape(batch, length, 3, -1, head_dim).unbind(dim=2)
    return q, k, v
|
| |
|
def optimized_attention(qkv, num_heads):
    """Unpack a (q, k, v) sequence and call ``attention.optimized_attention`` with ``num_heads``."""
    query, key, value = qkv[0], qkv[1], qkv[2]
    return attention.optimized_attention(query, key, value, num_heads)
|
| |
|
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused qkv projection and optional per-head q/k norm.

    The work is split into two halves so that callers can run one attention over several
    token streams:

    * ``pre_attention``: project, split and normalise q/k.
    * ``post_attention``: output projection and dropout.

    With ``pre_only=True`` there is no output projection. ``qk_scale`` and ``rmsnorm``
    are accepted but not used by this class.
    """
    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        proj_drop: float = 0.0,
        attn_mode: str = "xformers",
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
        rmsnorm: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        if not pre_only:
            self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
            self.proj_drop = nn.Dropout(proj_drop)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode
        self.pre_only = pre_only

        def make_head_norm():
            # Builds one per-head normaliser; called once for q and once for k.
            if qk_norm == "rms":
                return RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            if qk_norm == "ln":
                return operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            if qk_norm is None:
                return nn.Identity()
            raise ValueError(qk_norm)

        self.ln_q = make_head_norm()
        self.ln_k = make_head_norm()

    def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
        # Returns a (q, k, v) tuple: q and k flattened back to (B, L, -1), v left per-head.
        _batch, _tokens, _channels = x.shape  # input must be 3-D
        q, k, v = split_qkv(self.qkv(x), self.head_dim)
        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
        k = self.ln_k(k).reshape(k.shape[0], k.shape[1], -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        return self.proj_drop(self.proj(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = optimized_attention(self.pre_attention(x), num_heads=self.num_heads)
        return self.post_attention(attended)
|
| |
|
| |
|
class RMSNorm(torch.nn.Module):
    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
    ):
        """
        Root-mean-square normalisation over the last dimension.

        Args:
            dim (int): Size of the normalised (last) dimension.
            elementwise_affine (bool): If true, allocate a per-channel ``weight`` parameter.
                It is created with ``torch.empty``, i.e. uninitialised until loaded.
            eps (float): Added to the mean square before the reciprocal square root.
        Attributes:
            eps (float): The stability constant.
            learnable_scale (bool): Whether ``weight`` is applied.
            weight (nn.Parameter or None): Per-channel scale, registered as ``None`` when unused.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if not self.learnable_scale:
            self.register_parameter("weight", None)
        else:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def _norm(self, x):
        """Return ``x`` divided by the RMS of its last dimension (plus ``eps``)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x):
        """Normalise ``x``; when ``learnable_scale`` is set, multiply by ``weight`` cast to x's device/dtype."""
        normed = self._norm(x)
        if not self.learnable_scale:
            return normed
        return normed * self.weight.to(device=x.device, dtype=x.dtype)
|
| |
|
| |
|
class SwiGLUFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        """
        Initialize the SwiGLU feed-forward module: ``w2(silu(w1(x)) * w3(x))``.

        Args:
            dim (int): Input and output dimension.
            hidden_dim (int): Nominal hidden dimension. It is scaled by 2/3, optionally
                multiplied by ``ffn_dim_multiplier``, then rounded up to a multiple of
                ``multiple_of``.
            multiple_of (int): The final hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Extra multiplier for the hidden dimension.
            dtype, device: Forwarded to the linear layers (default: torch defaults).
            operations: Namespace providing a ``Linear`` class, as used by the other layers
                in this file. Defaults to ``torch.nn`` so existing callers are unaffected.

        Attributes:
            w1 (Linear): Gate projection, dim -> hidden.
            w2 (Linear): Output projection, hidden -> dim.
            w3 (Linear): Value projection, dim -> hidden.
        """
        super().__init__()
        if operations is None:
            operations = nn
        hidden_dim = int(2 * hidden_dim / 3)

        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round up to the next multiple of `multiple_of`.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
        self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
        self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
| |
|
| |
|
class DismantledBlock(nn.Module):
    """
    A DiT block with gated adaptive layer norm (adaLN) conditioning.

    The block is split into two halves:

    * ``pre_attention``: adaLN modulation plus the q/k/v projection.
    * ``post_attention``: gated residual attention output plus a gated MLP.

    This split lets ``_block_mixing`` run a single attention call over the q/k/v of two
    blocks concatenated along the token axis. ``forward`` runs both halves for standalone
    use.

    With ``pre_only=True`` only the pre-attention half exists: the attention has no output
    projection, and no ``norm2``/``mlp`` are created.
    """

    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        qkv_bias: bool = False,
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        qk_norm: Optional[str] = None,
        dtype=None,
        device=None,
        operations=None,
        **block_kwargs,  # accepted for call-site compatibility; not used below
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        # norm1/norm2 carry no affine parameters; the per-sample scale/shift come from
        # adaLN_modulation(c) instead.
        if not rmsnorm:
            self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        else:
            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_mode=attn_mode,
            pre_only=pre_only,
            qk_norm=qk_norm,
            rmsnorm=rmsnorm,
            dtype=dtype,
            device=device,
            operations=operations
        )
        if not pre_only:
            if not rmsnorm:
                self.norm2 = operations.LayerNorm(
                    hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
                )
            else:
                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if not pre_only:
            if not swiglu:
                self.mlp = Mlp(
                    in_features=hidden_size,
                    hidden_features=mlp_hidden_dim,
                    act_layer=lambda: nn.GELU(approximate="tanh"),
                    drop=0,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )
            else:
                # NOTE(review): unlike the Mlp branch, dtype/device/operations are not
                # passed here, so these weights use SwiGLUFeedForward's own defaults.
                self.mlp = SwiGLUFeedForward(
                    dim=hidden_size,
                    hidden_dim=mlp_hidden_dim,
                    multiple_of=256,
                )
        self.scale_mod_only = scale_mod_only
        # Number of hidden_size-wide modulation vectors produced from c:
        #   full block:     shift/scale/gate for attention and for the MLP (6),
        #                   or only scale/gate for each (4) when scale_mod_only;
        #   pre_only block: shift/scale for the attention input (2), or scale only (1).
        if not scale_mod_only:
            n_mods = 6 if not pre_only else 2
        else:
            n_mods = 4 if not pre_only else 1
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), operations.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device)
        )
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Modulate norm1(x) with c and project it to q/k/v.

        Returns ``(qkv, intermediates)``. For a full block, ``intermediates`` is
        ``(x, gate_msa, shift_mlp, scale_mlp, gate_mlp)``, exactly the trailing arguments of
        ``post_attention``. For a pre_only block it is ``None``. The modulation vectors come
        from ``adaLN_modulation(c)`` chunked along dim 1, in the order unpacked below. A
        ``None`` shift means "no shift" (see ``modulate``).
        """
        if not self.pre_only:
            if not self.scale_mod_only:
                (
                    shift_msa,
                    scale_msa,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.adaLN_modulation(c).chunk(6, dim=1)
            else:
                shift_msa = None
                shift_mlp = None
                (
                    scale_msa,
                    gate_msa,
                    scale_mlp,
                    gate_mlp,
                ) = self.adaLN_modulation(
                    c
                ).chunk(4, dim=1)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, (
                x,
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            )
        else:
            if not self.scale_mod_only:
                (
                    shift_msa,
                    scale_msa,
                ) = self.adaLN_modulation(
                    c
                ).chunk(2, dim=1)
            else:
                shift_msa = None
                # Single modulation vector: adaLN output is used directly as the scale.
                scale_msa = self.adaLN_modulation(c)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, None

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        """Apply the gated attention residual, then the gated MLP residual on modulated norm2(x)."""
        assert not self.pre_only
        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(
            modulate(self.norm2(x), shift_mlp, scale_mlp)
        )
        return x

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Standalone use: pre_attention -> attention over x's own q/k/v -> post_attention."""
        assert not self.pre_only
        qkv, intermediates = self.pre_attention(x, c)
        attn = optimized_attention(
            qkv,
            num_heads=self.attn.num_heads,
        )
        return self.post_attention(attn, *intermediates)
|
| |
|
| |
|
def block_mixing(*args, use_checkpoint=True, **kwargs):
    """Run ``_block_mixing``, wrapped in non-reentrant activation checkpointing unless disabled."""
    if not use_checkpoint:
        return _block_mixing(*args, **kwargs)
    return torch.utils.checkpoint.checkpoint(
        _block_mixing, *args, use_reentrant=False, **kwargs
    )
|
| |
|
| |
|
def _block_mixing(context, x, context_block, x_block, c):
    """Joint attention over context and image tokens.

    Both blocks produce q/k/v. These are concatenated along dim 1 (context tokens first),
    and one attention call runs with ``x_block``'s head count. The result is split back
    and fed through each block's ``post_attention``. The returned context is ``None``
    when ``context_block`` is pre_only.
    """
    context_qkv, context_intermediates = context_block.pre_attention(context, c)
    x_qkv, x_intermediates = x_block.pre_attention(x, c)

    joint_qkv = tuple(torch.cat((context_qkv[idx], x_qkv[idx]), dim=1) for idx in range(3))
    joint_attn = optimized_attention(
        joint_qkv,
        num_heads=x_block.attn.num_heads,
    )

    n_context = context_qkv[0].shape[1]
    context_attn = joint_attn[:, :n_context]
    x_attn = joint_attn[:, n_context:]

    context = None if context_block.pre_only else context_block.post_attention(context_attn, *context_intermediates)
    x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x
|
| |
|
| |
|
class JointBlock(nn.Module):
    """Pair of DismantledBlocks (context stream and image stream) run through block_mixing.

    Kept as a separate module so it can serve as a single FSDP unit. ``pre_only`` is
    required and applies to the context block only; the image block is never pre_only.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__()
        context_pre_only = kwargs.pop("pre_only")
        kwargs.setdefault("qk_norm", None)
        self.context_block = DismantledBlock(*args, pre_only=context_pre_only, **kwargs)
        self.x_block = DismantledBlock(*args, pre_only=False, **kwargs)

    def forward(self, *args, **kwargs):
        blocks = {"context_block": self.context_block, "x_block": self.x_block}
        return block_mixing(*args, **blocks, **kwargs)
|
| |
|
| |
|
class FinalLayer(nn.Module):
    """
    Output head of the DiT.

    Applies a non-affine LayerNorm modulated by adaLN shift/scale from ``c``, then a
    linear projection. The projection has ``patch_size**2 * out_channels`` outputs, or
    ``total_out_channels`` when that is given.
    """

    def __init__(
        self,
        hidden_size: int,
        patch_size: int,
        out_channels: int,
        total_out_channels: Optional[int] = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        if total_out_channels is None:
            out_dim = patch_size * patch_size * out_channels
        else:
            out_dim = total_out_channels
        self.linear = operations.Linear(hidden_size, out_dim, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        return self.linear(modulate(self.norm_final(x), shift, scale))
|
| |
|
class SelfAttentionContext(nn.Module):
    """Multi-head self-attention over context tokens with a fused, biased qkv projection.

    NOTE: the ``dim_head`` argument is accepted for signature compatibility but ignored;
    the per-head size is always ``dim // heads``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
        super().__init__()
        dim_head = dim // heads
        inner_dim = dim

        self.heads = heads
        self.dim_head = dim_head

        self.qkv = operations.Linear(dim, dim * 3, bias=True, dtype=dtype, device=device)

        self.proj = operations.Linear(inner_dim, dim, dtype=dtype, device=device)

    def forward(self, x):
        qkv = self.qkv(x)
        q, k, v = split_qkv(qkv, self.dim_head)
        # Flatten the head axes of q AND k back to (B, L, dim), the same layout
        # SelfAttention.pre_attention hands to optimized_attention (previously only q
        # was flattened here).
        q = q.reshape(q.shape[0], q.shape[1], -1)
        k = k.reshape(k.shape[0], k.shape[1], -1)
        x = optimized_attention((q, k, v), self.heads)
        return self.proj(x)
|
| |
|
class ContextProcessorBlock(nn.Module):
    """Pre-norm transformer block (self-attention then MLP, each with a residual) for context tokens."""

    def __init__(self, context_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm1 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attn = SelfAttentionContext(context_size, dtype=dtype, device=device, operations=operations)
        self.norm2 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.mlp = Mlp(in_features=context_size, hidden_features=(context_size * 4), act_layer=lambda: nn.GELU(approximate="tanh"), drop=0, dtype=dtype, device=device, operations=operations)

    def forward(self, x):
        # Out-of-place residual adds. The previous `x += ...` mutated the caller's tensor
        # (the raw context passed into the first block), and in-place updates of a tensor
        # that autograd saved for backward (norm1's input) can fail during training.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
|
| |
|
class ContextProcessor(nn.Module):
    """Apply ``num_layers`` ContextProcessorBlocks in sequence, then a non-affine LayerNorm."""

    def __init__(self, context_size, num_layers, dtype=None, device=None, operations=None):
        super().__init__()
        blocks = [
            ContextProcessorBlock(context_size, dtype=dtype, device=device, operations=operations)
            for _ in range(num_layers)
        ]
        self.layers = torch.nn.ModuleList(blocks)
        self.norm = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
|
| |
|
class MMDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Pipeline:

    1. Image latents are patchified by ``x_embedder`` and combined with a positional
       embedding.
    2. They are processed jointly with context tokens by a stack of ``JointBlock``s,
       conditioned on the timestep embedding (plus the optional ``y`` vector).
    3. ``final_layer`` projects the result back to patch pixels, and ``unpatchify``
       reassembles the image.
    """

    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 28,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_config: Optional[Dict] = None,
        compile_core: bool = False,
        use_checkpoint: bool = False,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches = None,
        qk_norm: Optional[str] = None,
        qkv_bias: bool = True,
        context_processor_layers = None,
        context_size = 4096,
        num_blocks = None,
        final_layer = True,
        dtype = None,
        device = None,
        operations = None,
    ):
        super().__init__()
        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = default(out_channels, default_out_channels)
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        self.pos_embed_max_size = pos_embed_max_size

        # Width and head count both derive from depth: `depth` heads of 64 channels each.
        self.hidden_size = 64 * depth
        num_heads = depth
        if num_blocks is None:
            num_blocks = depth

        self.depth = depth
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=self.pos_embed_max_size is None,
            dtype=dtype,
            device=device,
            operations=operations
        )
        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.y_embedder = None
        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = VectorEmbedder(adm_in_channels, self.hidden_size, dtype=dtype, device=device, operations=operations)

        if context_processor_layers is not None:
            self.context_processor = ContextProcessor(context_size, context_processor_layers, dtype=dtype, device=device, operations=operations)
        else:
            self.context_processor = None

        self.context_embedder = nn.Identity()
        if context_embedder_config is not None:
            if context_embedder_config["target"] == "torch.nn.Linear":
                self.context_embedder = operations.Linear(**context_embedder_config["params"], dtype=dtype, device=device)

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size, dtype=dtype, device=device))

        # A learned/loaded pos_embed table (cropped per input) when num_patches is given;
        # otherwise a sin/cos embedding is generated on the fly in cropped_pos_embed.
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.empty(1, num_patches, self.hidden_size, dtype=dtype, device=device),
            )
        else:
            self.pos_embed = None

        self.use_checkpoint = use_checkpoint
        self.joint_blocks = nn.ModuleList(
            [
                JointBlock(
                    self.hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    attn_mode=attn_mode,
                    # The last block's context stream is only needed for its k/v.
                    pre_only=(i == num_blocks - 1) and final_layer,
                    rmsnorm=rmsnorm,
                    scale_mod_only=scale_mod_only,
                    swiglu=swiglu,
                    qk_norm=qk_norm,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )
                for i in range(num_blocks)
            ]
        )

        if final_layer:
            self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

        if compile_core:
            # compile_core=True is deliberately rejected. NOTE(review): asserts are
            # stripped under `python -O`, in which case the compile line below runs.
            assert False
            self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)

    @staticmethod
    def _patch_count(size, patch):
        """Number of patches along one axis after padding ``size`` up to a multiple of ``patch`` (ceil division).

        The former ``(size + 1) // patch`` only equals the ceiling for ``patch == 2``.
        """
        return (size + patch - 1) // patch

    def cropped_pos_embed(self, hw, device=None):
        """Positional embedding for an input of spatial size ``hw`` (in pixels/latent cells).

        Without a stored ``pos_embed`` table, a sin/cos embedding is generated. Otherwise
        the stored square table is centre-cropped to the patch grid of ``hw``.
        """
        p = self.x_embedder.patch_size[0]
        h, w = hw
        h = self._patch_count(h, p)
        w = self._patch_count(w, p)
        if self.pos_embed is None:
            return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
        assert self.pos_embed_max_size is not None
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
        top = (self.pos_embed_max_size - h) // 2
        left = (self.pos_embed_max_size - w) // 2
        spatial_pos_embed = rearrange(
            self.pos_embed,
            "1 (h w) c -> 1 h w c",
            h=self.pos_embed_max_size,
            w=self.pos_embed_max_size,
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
        return spatial_pos_embed

    def unpatchify(self, x, hw=None):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W), where H and W are the patch-grid sizes times patch_size
            (i.e. the padded size; callers crop back to the input size).
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        if hw is None:
            h = w = int(x.shape[1] ** 0.5)
        else:
            h, w = hw
            h = self._patch_count(h, p)
            w = self._patch_count(w, p)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward_core_with_concat(
        self,
        x: torch.Tensor,
        c_mod: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        control = None,
    ) -> torch.Tensor:
        """Run the joint blocks and the final layer.

        Register tokens, if any, are prepended to the context first. When ``control``
        provides an ``"output"`` list, entry ``i`` (if not ``None``) is added to ``x``
        after block ``i``.
        """
        if self.register_length > 0:
            context = torch.cat(
                (
                    repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
                    default(context, torch.Tensor([]).type_as(x)),
                ),
                1,
            )

        # Looked up once; a missing "output" entry simply means no residual additions.
        control_o = control.get("output") if control is not None else None
        for i, block in enumerate(self.joint_blocks):
            context, x = block(
                context,
                x,
                c=c_mod,
                use_checkpoint=self.use_checkpoint,
            )
            if control_o is not None and i < len(control_o):
                add = control_o[i]
                if add is not None:
                    x += add

        x = self.final_layer(x, c_mod)
        return x

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        control = None,
    ) -> torch.Tensor:
        """
        Forward pass of DiT.

        Args:
            x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images).
            t: (N,) tensor of diffusion timesteps.
            y: optional conditioning vector. When ``y_embedder`` exists it is passed through
                it (so its last dim must be ``adm_in_channels``) and added to the timestep
                embedding.
            context: token conditioning. It is run through ``context_processor`` when one is
                configured (which then requires a non-None context), and through
                ``context_embedder``.
            control: optional dict whose ``"output"`` list is added to ``x`` block by block.

        Returns:
            A tensor cropped back to the input's spatial size H x W.
        """

        if self.context_processor is not None:
            context = self.context_processor(context)

        hw = x.shape[-2:]
        x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
        c = self.t_embedder(t, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            y = self.y_embedder(y)
            c = c + y

        if context is not None:
            context = self.context_embedder(context)

        x = self.forward_core_with_concat(x, c, context, control)

        x = self.unpatchify(x, hw=hw)
        # Drop the rows/columns added by patch padding.
        return x[:,:,:hw[-2],:hw[-1]]
|
| |
|
| |
|
class OpenAISignatureMMDITWrapper(MMDiT):
    """MMDiT callable with an ``(x, timesteps, context, y, control, **kwargs)`` signature.

    Arguments are forwarded by keyword to ``MMDiT.forward``. Any extra keyword arguments
    are ignored.
    """

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        control = None,
        **kwargs,
    ) -> torch.Tensor:
        return super().forward(x, timesteps, y=y, context=context, control=control)
|
| |
|
| |
|