|
|
import torch
import torch.nn as nn

from comfy.ldm.modules.diffusionmodules.mmdit import (
    TimestepEmbedder,
    PatchEmbed,
)
import comfy.latent_formats
import comfy.ops

from .poolers import AttentionPool
from .models import HunYuanDiTBlock, calc_rope
|
|
|
|
|
|
|
class HunYuanControlNet(nn.Module):
    """
    ControlNet branch for HunYuanDiT.

    The control ``hint`` is embedded with the same patch embedder as the
    latent ``x``, projected by ``before_proj`` and added to the embedded
    latent. The sum is run through a stack of ``HunYuanDiTBlock`` layers, and
    the hidden state after every block is projected by the matching entry of
    ``after_proj_list`` and returned.

    Parameters
    ----------
    input_size: int or tuple
        Input size forwarded to ``PatchEmbed``.
    patch_size: int
        Patch size used by ``PatchEmbed`` and for the image RoPE.
    in_channels: int
        Number of channels of both the latent ``x`` and the control ``hint``
        (both go through the same patch embedder).
    hidden_size: int
        Hidden size of the transformer blocks.
    depth: int
        Stored on the instance only; this control branch always builds 19
        blocks regardless of this value.
    num_heads: int
        Number of attention heads; ``hidden_size // num_heads`` is also the
        head dimension passed to ``calc_rope``.
    mlp_ratio: float
        MLP ratio forwarded to every ``HunYuanDiTBlock``.
    text_states_dim: int
        Feature width of the CLIP text states (``context``). Also the output
        width of ``mlp_t5``.
    text_states_dim_t5: int
        Feature width of the T5 text states.
    text_len: int
        Number of trailing CLIP tokens that may be replaced by learned padding.
    text_len_t5: int
        Number of trailing T5 tokens that may be replaced by learned padding.
        Also passed to ``AttentionPool``.
    qk_norm: bool
        Forwarded to every ``HunYuanDiTBlock``.
    size_cond: bool
        If True, widens the extra-conditioning input by ``6 * 256``.
    use_style_cond: bool
        If True, adds a 1-entry style embedding and widens the
        extra-conditioning input by ``hidden_size``.
    learn_sigma: bool
        Selects ``out_channels`` (``2 * in_channels`` when True); not
        otherwise used by this module.
    norm: str
        Norm type forwarded to every ``HunYuanDiTBlock`` as ``norm_type``.
    log_fn: callable
        Logging function; stored on the instance.
    attn_precision:
        Forwarded to every ``HunYuanDiTBlock``.
    dtype, device:
        Dtype/device for created parameters; ``dtype`` is also passed to the
        timestep embedder in ``forward``.
    operations:
        Namespace providing ``Linear`` (e.g. ``comfy.ops``); also forwarded to
        ``PatchEmbed``, ``TimestepEmbedder``, ``AttentionPool`` and the
        blocks. Must not be None.
    **kwargs:
        Ignored.
    """

    def __init__(
        self,
        input_size: tuple = 128,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1408,
        depth: int = 40,
        num_heads: int = 16,
        mlp_ratio: float = 4.3637,
        text_states_dim=1024,
        text_states_dim_t5=2048,
        text_len=77,
        text_len_t5=256,
        qk_norm=True,
        size_cond=False,
        use_style_cond=False,
        learn_sigma=True,
        norm="layer",
        log_fn: callable = print,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype
        self.latent_format = comfy.latent_formats.SDXL

        # Maps T5 features (text_states_dim_t5) to the CLIP width
        # (text_states_dim) so both streams can be concatenated in forward().
        # Built from `operations.Linear`, like every other linear layer here,
        # so the supplied ops namespace (weight casting etc.) applies to it too.
        self.mlp_t5 = nn.Sequential(
            operations.Linear(
                self.text_states_dim_t5,
                self.text_states_dim_t5 * 4,
                bias=True,
                dtype=dtype,
                device=device,
            ),
            nn.SiLU(),
            operations.Linear(
                self.text_states_dim_t5 * 4,
                self.text_states_dim,
                bias=True,
                dtype=dtype,
                device=device,
            ),
        )

        # Learned replacement rows for masked-out text tokens: the first
        # `text_len` rows are used for CLIP tokens, the remaining
        # `text_len_t5` rows for T5 tokens (see forward()).
        self.text_embedding_padding = nn.Parameter(
            torch.randn(
                self.text_len + self.text_len_t5,
                self.text_states_dim,
                dtype=dtype,
                device=device,
            )
        )

        # Attention pooling over the raw T5 states.
        pooler_out_dim = 1024
        self.pooler = AttentionPool(
            self.text_len_t5,
            self.text_states_dim_t5,
            num_heads=8,
            output_dim=pooler_out_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Width of the extra conditioning vector consumed by extra_embedder.
        self.extra_in_dim = pooler_out_dim

        if self.size_cond:
            # NOTE(review): forward() never appends `image_meta_size`, so with
            # size_cond=True the vector built there looks 6 * 256 narrower
            # than extra_embedder expects -- confirm before enabling.
            self.extra_in_dim += 6 * 256

        if self.use_style_cond:
            self.style_embedder = nn.Embedding(
                1, hidden_size, dtype=dtype, device=device
            )
            self.extra_in_dim += hidden_size

        # Shared by the latent and the control hint in forward().
        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            hidden_size,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations
        )
        self.extra_embedder = nn.Sequential(
            operations.Linear(
                self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
            ),
            nn.SiLU(),
            operations.Linear(
                hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
            ),
        )

        # The control branch always has 19 blocks, independent of `depth`.
        self.blocks = nn.ModuleList(
            [
                HunYuanDiTBlock(
                    hidden_size=hidden_size,
                    c_emb_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    text_states_dim=self.text_states_dim,
                    qk_norm=qk_norm,
                    norm_type=self.norm,
                    skip=False,
                    attn_precision=attn_precision,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(19)
            ]
        )

        # Projection of the embedded hint before it is added to the latent.
        self.before_proj = operations.Linear(
            self.hidden_size, self.hidden_size, dtype=dtype, device=device
        )

        # One output projection per block; each produces one control tensor.
        self.after_proj_list = nn.ModuleList(
            [
                operations.Linear(
                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
                )
                for _ in range(len(self.blocks))
            ]
        )

    def forward(
        self,
        x,
        hint,
        timesteps,
        context,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        return_dict=False,
        **kwarg,
    ):
        """
        Compute the per-block control residuals.

        Parameters
        ----------
        x: torch.Tensor
            Latent, (B, C, H, W).
        hint: torch.Tensor
            Control input with ``in_channels`` channels; a batch of 1 is
            repeated to match ``x.shape[0]``.
        timesteps: torch.Tensor
            (B)
        context: torch.Tensor
            CLIP text embedding, (B, L_clip, text_states_dim). Not modified.
        text_embedding_mask: torch.Tensor, optional
            CLIP text embedding mask, (B, L_clip). Within the last
            ``text_len`` tokens, positions whose mask is false are replaced by
            the learned padding. If None, no replacement is done.
        encoder_hidden_states_t5: torch.Tensor
            T5 text embedding, (B, L_t5, text_states_dim_t5). Required.
        text_embedding_mask_t5: torch.Tensor, optional
            T5 text embedding mask, (B, L_t5); same replacement rule as the
            CLIP mask, over the last ``text_len_t5`` tokens.
        image_meta_size: torch.Tensor, optional
            Accepted for interface compatibility; currently unused.
        style: torch.Tensor, optional
            (B) indices into ``style_embedder``; requires
            ``use_style_cond=True``.
        return_dict: bool
            Ignored; a dict is always returned.

        Returns
        -------
        dict
            ``{"output": controls}`` where ``controls`` holds one tensor per
            block: ``after_proj_list[i]`` applied to the hidden state after
            block ``i``.
        """
        condition = hint
        if condition.shape[0] == 1:
            # Broadcast a single control hint across the whole batch.
            condition = torch.repeat_interleave(condition, x.shape[0], dim=0)

        # mlp_t5 only contains linear layers, which act on the last dimension,
        # so no flatten/unflatten round-trip is needed (and, unlike `.view`,
        # non-contiguous inputs are accepted).
        text_states_t5 = self.mlp_t5(encoder_hidden_states_t5)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, context)

        text_states = context
        if text_embedding_mask is not None:
            # Clone before the in-place slice assignment so the caller's
            # `context` tensor is left untouched.
            text_states = context.clone()
            text_states[:, -self.text_len :] = torch.where(
                text_embedding_mask[:, -self.text_len :].bool().unsqueeze(2),
                text_states[:, -self.text_len :],
                padding[: self.text_len],
            )
        if text_embedding_mask_t5 is not None:
            # text_states_t5 is a fresh output of mlp_t5, so writing into it
            # cannot affect the caller.
            text_states_t5[:, -self.text_len_t5 :] = torch.where(
                text_embedding_mask_t5[:, -self.text_len_t5 :].bool().unsqueeze(2),
                text_states_t5[:, -self.text_len_t5 :],
                padding[self.text_len :],
            )

        text_states = torch.cat([text_states, text_states_t5], dim=1)

        # Image RoPE is derived from the un-embedded latent.
        freqs_cis_img = calc_rope(
            x, self.patch_size, self.hidden_size // self.num_heads
        )

        t = self.t_embedder(timesteps, dtype=self.dtype)
        x = self.x_embedder(x)

        # Extra conditioning vector: pooled T5 states (+ style embedding).
        # NOTE: image_meta_size is not consumed here (see size_cond).
        extra_vec = self.pooler(encoder_hidden_states_t5)
        if style is not None:
            style_embedding = self.style_embedder(style)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        c = t + self.extra_embedder(extra_vec)

        condition = self.x_embedder(condition)

        controls = []
        x = x + self.before_proj(condition)
        for block, after_proj in zip(self.blocks, self.after_proj_list):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(after_proj(x))

        return {"output": controls}
|
|
|