|
|
| from __future__ import annotations
|
|
|
| from typing import List, Optional, Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import comfy.ldm.common_dit
|
|
|
| from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
|
| from comfy.ldm.modules.attention import optimized_attention_masked
|
| from comfy.ldm.flux.layers import EmbedND
|
|
|
|
|
| def modulate(x, scale):
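| """Apply adaLN-style modulation: scale has shape (B, dim) and is broadcast over the sequence dimension."""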
|
| return x * (1 + scale.unsqueeze(1))
|
|
|
|
|
| class JointAttention(nn.Module):
|
| """Multi-head attention module."""
|
|
|
| def __init__(
|
| self,
|
| dim: int,
|
| n_heads: int,
|
| n_kv_heads: Optional[int],
|
| qk_norm: bool,
|
| operation_settings={},
|
| ):
|
| """
|
| Initialize the Attention module.
|
|
|
| Args:
|
| dim (int): Embedding dimension of the input features.
|
| n_heads (int): Number of attention heads.
|
| n_kv_heads (Optional[int]): Number of key/value heads when using GQA, or None to use n_heads.
|
| qk_norm (bool): Whether to apply RMSNorm to the query and key projections.
|
|
|
| """
|
| super().__init__()
|
| self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
|
| self.n_local_heads = n_heads
|
| self.n_local_kv_heads = self.n_kv_heads
|
| self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
| self.head_dim = dim // n_heads
|
|
|
| self.qkv = operation_settings.get("operations").Linear(
|
| dim,
|
| (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
|
| bias=False,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| )
|
| self.out = operation_settings.get("operations").Linear(
|
| n_heads * self.head_dim,
|
| dim,
|
| bias=False,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| )
|
|
|
| if qk_norm:
|
| self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
| self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
|
| else:
|
| self.q_norm = self.k_norm = nn.Identity()
|
|
|
| @staticmethod
|
| def apply_rotary_emb(
|
| x_in: torch.Tensor,
|
| freqs_cis: torch.Tensor,
|
| ) -> torch.Tensor:
|
| """
|
| Apply rotary embeddings to input tensors using the given frequency
|
| tensor.
|
|
|
| This function applies rotary embeddings to the given query 'xq' and
|
| key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
|
| input tensors are reshaped as complex numbers, and the frequency tensor
|
| is reshaped for broadcasting compatibility. The resulting tensors
|
| contain rotary embeddings and are returned as real tensors.
|
|
|
| Args:
|
| x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
|
| freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
|
| exponentials.
|
|
|
| Returns:
|
| Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
| and key tensor with rotary embeddings.
|
| """
|
|
|
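| # View the last dim as (d/2) pairs and apply the per-pair 2x2 cos/sin transform stored in freqs_cis.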
| t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
|
| t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
| return t_out.reshape(*x_in.shape)
|
|
|
| def forward(
|
| self,
|
| x: torch.Tensor,
|
| x_mask: torch.Tensor,
|
| freqs_cis: torch.Tensor,
|
| ) -> torch.Tensor:
|
| """
|
|
|
| Args:
|
| x:
|
| x_mask:
|
| freqs_cis:
|
|
|
| Returns:
|
|
|
| """
|
| bsz, seqlen, _ = x.shape
|
|
|
| xq, xk, xv = torch.split(
|
| self.qkv(x),
|
| [
|
| self.n_local_heads * self.head_dim,
|
| self.n_local_kv_heads * self.head_dim,
|
| self.n_local_kv_heads * self.head_dim,
|
| ],
|
| dim=-1,
|
| )
|
| xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
| xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
| xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
|
|
| xq = self.q_norm(xq)
|
| xk = self.k_norm(xk)
|
|
|
| xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
| xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
|
|
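| # Grouped-query attention: repeat the key/value heads so they match the number of query heads.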
| n_rep = self.n_local_heads // self.n_local_kv_heads
|
| if n_rep >= 1:
|
| xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
| xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
| output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
|
|
|
| return self.out(output)
|
|
|
|
|
| class FeedForward(nn.Module):
|
| def __init__(
|
| self,
|
| dim: int,
|
| hidden_dim: int,
|
| multiple_of: int,
|
| ffn_dim_multiplier: Optional[float],
|
| operation_settings={},
|
| ):
|
| """
|
| Initialize the FeedForward module.
|
|
|
| Args:
|
| dim (int): Input dimension.
|
| hidden_dim (int): Hidden dimension of the feedforward layer.
|
| multiple_of (int): Value to ensure hidden dimension is a multiple
|
| of this value.
|
| ffn_dim_multiplier (float, optional): Custom multiplier for hidden
|
| dimension. Defaults to None.
|
|
|
| """
|
| super().__init__()
|
|
|
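| # Optionally rescale the hidden dim, then round it up to the nearest multiple of `multiple_of`.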
| if ffn_dim_multiplier is not None:
|
| hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
| hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
|
|
| self.w1 = operation_settings.get("operations").Linear(
|
| dim,
|
| hidden_dim,
|
| bias=False,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| )
|
| self.w2 = operation_settings.get("operations").Linear(
|
| hidden_dim,
|
| dim,
|
| bias=False,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| )
|
| self.w3 = operation_settings.get("operations").Linear(
|
| dim,
|
| hidden_dim,
|
| bias=False,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| )
|
|
|
|
|
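| # SwiGLU-style gating: w2(silu(w1(x)) * w3(x)).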
| def _forward_silu_gating(self, x1, x3):
|
| return F.silu(x1) * x3
|
|
|
| def forward(self, x):
|
| return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
|
|
|
|
| class JointTransformerBlock(nn.Module):
|
| def __init__(
|
| self,
|
| layer_id: int,
|
| dim: int,
|
| n_heads: int,
|
| n_kv_heads: int,
|
| multiple_of: int,
|
| ffn_dim_multiplier: float,
|
| norm_eps: float,
|
| qk_norm: bool,
|
| modulation=True,
|
| operation_settings={},
|
| ) -> None:
|
| """
|
| Initialize a TransformerBlock.
|
|
|
| Args:
|
| layer_id (int): Identifier for the layer.
|
| dim (int): Embedding dimension of the input features.
|
| n_heads (int): Number of attention heads.
|
| n_kv_heads (Optional[int]): Number of attention heads in key and
|
| value features (if using GQA), or set to None for the same as
|
| query.
|
| multiple_of (int): Round the feed-forward hidden dimension up to a multiple of this value.
|
| ffn_dim_multiplier (float): Optional custom multiplier for the feed-forward hidden dimension.
|
| norm_eps (float): Epsilon used by the RMSNorm layers.
|
| qk_norm (bool): Whether to apply RMSNorm to queries and keys in attention.
|
| modulation (bool): Whether the block applies adaLN (timestep) modulation.
|
|
|
| """
|
| super().__init__()
|
| self.dim = dim
|
| self.head_dim = dim // n_heads
|
| self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
|
| self.feed_forward = FeedForward(
|
| dim=dim,
|
| hidden_dim=4 * dim,
|
| multiple_of=multiple_of,
|
| ffn_dim_multiplier=ffn_dim_multiplier,
|
| operation_settings=operation_settings,
|
| )
|
| self.layer_id = layer_id
|
| self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
| self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
|
|
| self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
| self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
|
|
| self.modulation = modulation
|
| if modulation:
|
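| # The adaLN conditioning vector (the timestep embedding in this model) has width min(dim, 1024) and is projected to 4 * dim scale/gate values.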
| self.adaLN_modulation = nn.Sequential(
|
| nn.SiLU(),
|
| operation_settings.get("operations").Linear(
|
| min(dim, 1024),
|
| 4 * dim,
|
| bias=True,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| ),
|
| )
|
|
|
| def forward(
|
| self,
|
| x: torch.Tensor,
|
| x_mask: torch.Tensor,
|
| freqs_cis: torch.Tensor,
|
| adaln_input: Optional[torch.Tensor]=None,
|
| ):
|
| """
|
| Perform a forward pass through the TransformerBlock.
|
|
|
| Args:
|
| x (torch.Tensor): Input tensor.
|
| x_mask (torch.Tensor): Attention mask for the input sequence.
|
| freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
| adaln_input (Optional[torch.Tensor]): Conditioning vector for adaLN modulation; required when modulation is enabled and must be None otherwise.
|
|
|
| Returns:
|
| torch.Tensor: Output tensor after applying attention and
|
| feedforward layers.
|
|
|
| """
|
| if self.modulation:
|
| assert adaln_input is not None
|
| scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
|
|
| x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
| self.attention(
|
| modulate(self.attention_norm1(x), scale_msa),
|
| x_mask,
|
| freqs_cis,
|
| )
|
| )
|
| x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
| self.feed_forward(
|
| modulate(self.ffn_norm1(x), scale_mlp),
|
| )
|
| )
|
| else:
|
| assert adaln_input is None
|
| x = x + self.attention_norm2(
|
| self.attention(
|
| self.attention_norm1(x),
|
| x_mask,
|
| freqs_cis,
|
| )
|
| )
|
| x = x + self.ffn_norm2(
|
| self.feed_forward(
|
| self.ffn_norm1(x),
|
| )
|
| )
|
| return x
|
|
|
|
|
| class FinalLayer(nn.Module):
|
| """
|
| The final layer of NextDiT.
|
| """
|
|
|
| def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
|
| super().__init__()
|
| self.norm_final = operation_settings.get("operations").LayerNorm(
|
| hidden_size,
|
| elementwise_affine=False,
|
| eps=1e-6,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| )
|
| self.linear = operation_settings.get("operations").Linear(
|
| hidden_size,
|
| patch_size * patch_size * out_channels,
|
| bias=True,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| )
|
|
|
| self.adaLN_modulation = nn.Sequential(
|
| nn.SiLU(),
|
| operation_settings.get("operations").Linear(
|
| min(hidden_size, 1024),
|
| hidden_size,
|
| bias=True,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| ),
|
| )
|
|
|
| def forward(self, x, c):
|
| scale = self.adaLN_modulation(c)
|
| x = modulate(self.norm_final(x), scale)
|
| x = self.linear(x)
|
| return x
|
|
|
|
|
| class NextDiT(nn.Module):
|
| """
|
| Diffusion model with a Transformer backbone.
|
| """
|
|
|
| def __init__(
|
| self,
|
| patch_size: int = 2,
|
| in_channels: int = 4,
|
| dim: int = 4096,
|
| n_layers: int = 32,
|
| n_refiner_layers: int = 2,
|
| n_heads: int = 32,
|
| n_kv_heads: Optional[int] = None,
|
| multiple_of: int = 256,
|
| ffn_dim_multiplier: Optional[float] = None,
|
| norm_eps: float = 1e-5,
|
| qk_norm: bool = False,
|
| cap_feat_dim: int = 5120,
|
| axes_dims: List[int] = (16, 56, 56),
|
| axes_lens: List[int] = (1, 512, 512),
|
| image_model=None,
|
| device=None,
|
| dtype=None,
|
| operations=None,
|
| ) -> None:
|
| super().__init__()
|
| self.dtype = dtype
|
| operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
| self.in_channels = in_channels
|
| self.out_channels = in_channels
|
| self.patch_size = patch_size
|
|
|
| self.x_embedder = operation_settings.get("operations").Linear(
|
| in_features=patch_size * patch_size * in_channels,
|
| out_features=dim,
|
| bias=True,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| )
|
|
|
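| # Refiner blocks: noise_refiner is timestep-modulated (adaLN) and runs on image tokens; context_refiner runs on caption tokens without modulation.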
| self.noise_refiner = nn.ModuleList(
|
| [
|
| JointTransformerBlock(
|
| layer_id,
|
| dim,
|
| n_heads,
|
| n_kv_heads,
|
| multiple_of,
|
| ffn_dim_multiplier,
|
| norm_eps,
|
| qk_norm,
|
| modulation=True,
|
| operation_settings=operation_settings,
|
| )
|
| for layer_id in range(n_refiner_layers)
|
| ]
|
| )
|
| self.context_refiner = nn.ModuleList(
|
| [
|
| JointTransformerBlock(
|
| layer_id,
|
| dim,
|
| n_heads,
|
| n_kv_heads,
|
| multiple_of,
|
| ffn_dim_multiplier,
|
| norm_eps,
|
| qk_norm,
|
| modulation=False,
|
| operation_settings=operation_settings,
|
| )
|
| for layer_id in range(n_refiner_layers)
|
| ]
|
| )
|
|
|
| self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
| self.cap_embedder = nn.Sequential(
|
| RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
|
| operation_settings.get("operations").Linear(
|
| cap_feat_dim,
|
| dim,
|
| bias=True,
|
| device=operation_settings.get("device"),
|
| dtype=operation_settings.get("dtype"),
|
| ),
|
| )
|
|
|
| self.layers = nn.ModuleList(
|
| [
|
| JointTransformerBlock(
|
| layer_id,
|
| dim,
|
| n_heads,
|
| n_kv_heads,
|
| multiple_of,
|
| ffn_dim_multiplier,
|
| norm_eps,
|
| qk_norm,
|
| operation_settings=operation_settings,
|
| )
|
| for layer_id in range(n_layers)
|
| ]
|
| )
|
| self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
|
| self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
|
|
| assert (dim // n_heads) == sum(axes_dims)
|
| self.axes_dims = axes_dims
|
| self.axes_lens = axes_lens
|
| self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
|
| self.dim = dim
|
| self.n_heads = n_heads
|
|
|
| def unpatchify(
|
| self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
|
| ) -> List[torch.Tensor]:
|
| """
|
| x: (N, T, patch_size**2 * C)
|
| imgs: list of (C, H, W) tensors, or an (N, C, H, W) tensor when return_tensor is True
|
| """
|
| pH = pW = self.patch_size
|
| imgs = []
|
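| # Skip the caption tokens (first cap_size[i] positions), then fold the patch sequence back into a (C, H, W) image.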
| for i in range(x.size(0)):
|
| H, W = img_size[i]
|
| begin = cap_size[i]
|
| end = begin + (H // pH) * (W // pW)
|
| imgs.append(
|
| x[i][begin:end]
|
| .view(H // pH, W // pW, pH, pW, self.out_channels)
|
| .permute(4, 0, 2, 1, 3)
|
| .flatten(3, 4)
|
| .flatten(1, 2)
|
| )
|
|
|
| if return_tensor:
|
| imgs = torch.stack(imgs, dim=0)
|
| return imgs
|
|
|
| def patchify_and_embed(
|
| self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
|
| ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
| bsz = len(x)
|
| pH = pW = self.patch_size
|
| device = x[0].device
|
| dtype = x[0].dtype
|
|
|
| if cap_mask is not None:
|
| l_effective_cap_len = cap_mask.sum(dim=1).tolist()
|
| else:
|
| l_effective_cap_len = [num_tokens] * bsz
|
|
|
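| # Convert a boolean/integer caption mask into an additive float mask: kept positions become 0, masked positions become -finfo.max.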
| if cap_mask is not None and not torch.is_floating_point(cap_mask):
|
| cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
|
|
|
| img_sizes = [(img.size(1), img.size(2)) for img in x]
|
| l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
|
|
|
| max_seq_len = max(
|
| (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
|
| )
|
| max_cap_len = max(l_effective_cap_len)
|
| max_img_len = max(l_effective_img_len)
|
|
|
| position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
|
|
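| # Position ids: axis 0 is the caption token index (image tokens all share the value cap_len); axes 1 and 2 are the patch row and column indices.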
| for i in range(bsz):
|
| cap_len = l_effective_cap_len[i]
|
| img_len = l_effective_img_len[i]
|
| H, W = img_sizes[i]
|
| H_tokens, W_tokens = H // pH, W // pW
|
| assert H_tokens * W_tokens == img_len
|
|
|
| position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
| position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
| row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
| col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
| position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
| position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
|
|
| freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
|
|
|
|
|
| cap_freqs_cis_shape = list(freqs_cis.shape)
|
|
|
| cap_freqs_cis_shape[1] = cap_feats.shape[1]
|
| cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
|
|
| img_freqs_cis_shape = list(freqs_cis.shape)
|
| img_freqs_cis_shape[1] = max_img_len
|
| img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
|
|
| for i in range(bsz):
|
| cap_len = l_effective_cap_len[i]
|
| img_len = l_effective_img_len[i]
|
| cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
| img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
|
|
|
|
|
| for layer in self.context_refiner:
|
| cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
|
|
|
|
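| # Patchify: reshape each (C, H, W) latent into a (H/pH * W/pW, pH*pW*C) sequence of flattened patches.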
| flat_x = []
|
| for i in range(bsz):
|
| img = x[i]
|
| C, H, W = img.size()
|
| img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
|
| flat_x.append(img)
|
| x = flat_x
|
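| # Pad each image token sequence to max_img_len and build an additive mask with -finfo.max on the padding positions.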
| padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
|
| padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
|
| for i in range(bsz):
|
| padded_img_embed[i, :l_effective_img_len[i]] = x[i]
|
| padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
|
|
|
| padded_img_embed = self.x_embedder(padded_img_embed)
|
| padded_img_mask = padded_img_mask.unsqueeze(1)
|
| for layer in self.noise_refiner:
|
| padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
|
|
|
| if cap_mask is not None:
|
| mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
| mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
|
| else:
|
| mask = None
|
|
|
| padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
|
| for i in range(bsz):
|
| cap_len = l_effective_cap_len[i]
|
| img_len = l_effective_img_len[i]
|
|
|
| padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
|
| padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
|
|
|
| return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
|
|
|
|
| def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
| """
|
| Forward pass of NextDiT.
|
| x: (N, C, H, W) tensor of latent images
|
| timesteps: (N,) tensor of diffusion timesteps
|
| context: (N, L, cap_feat_dim) tensor of caption features
|
| num_tokens: number of caption tokens assumed when attention_mask is None
|
| attention_mask: optional (N, L) caption mask
|
| """
|
| t = 1.0 - timesteps
|
| cap_feats = context
|
| cap_mask = attention_mask
|
| bs, c, h, w = x.shape
|
| x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
|
|
| t = self.t_embedder(t, dtype=x.dtype)
|
| adaln_input = t
|
|
|
| cap_feats = self.cap_embedder(cap_feats)
|
|
|
| x_is_tensor = isinstance(x, torch.Tensor)
|
| x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
|
| freqs_cis = freqs_cis.to(x.device)
|
|
|
| for layer in self.layers:
|
| x = layer(x, mask, freqs_cis, adaln_input)
|
|
|
| x = self.final_layer(x, adaln_input)
|
| x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
|
|
| return -x
|
|
|
|
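| # Illustrative smoke test: builds a tiny NextDiT and runs a single denoising step. This is a sketch
| # under assumptions, not part of the model definition above: it requires a working ComfyUI
| # environment (for the comfy imports at the top), uses torch.nn as the `operations` namespace
| # (nn.Linear / nn.LayerNorm accept the device/dtype keyword arguments this module passes), and the
| # sizes below are made-up toy values chosen so that dim // n_heads == sum(axes_dims).
| if __name__ == "__main__":
|     model = NextDiT(
|         patch_size=2,
|         in_channels=4,
|         dim=64,                 # dim // n_heads must equal sum(axes_dims)
|         n_layers=2,
|         n_refiner_layers=1,
|         n_heads=4,
|         multiple_of=32,
|         qk_norm=True,
|         cap_feat_dim=32,
|         axes_dims=(8, 4, 4),
|         axes_lens=(1, 512, 512),
|         dtype=torch.float32,
|         operations=torch.nn,
|     )
|     latent = torch.randn(1, 4, 32, 32)      # (N, C, H, W) latent image
|     timesteps = torch.tensor([0.5])         # (N,) diffusion timesteps
|     caption = torch.randn(1, 8, 32)         # (N, L, cap_feat_dim) caption features
|     with torch.no_grad():
|         out = model(latent, timesteps, caption, num_tokens=8)
|     print(out.shape)  # torch.Size([1, 4, 32, 32])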
|