| | import math |
| | from typing import List, Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor |
| |
|
| | from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch |
| |
|
| |
|
| | class ControlNetEmbedder(nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | img_size: int, |
| | patch_size: int, |
| | in_chans: int, |
| | attention_head_dim: int, |
| | num_attention_heads: int, |
| | adm_in_channels: int, |
| | num_layers: int, |
| | main_model_double: int, |
| | double_y_emb: bool, |
| | device: torch.device, |
| | dtype: torch.dtype, |
| | pos_embed_max_size: Optional[int] = None, |
| | operations = None, |
| | ): |
| | super().__init__() |
| | self.main_model_double = main_model_double |
| | self.dtype = dtype |
| | self.hidden_size = num_attention_heads * attention_head_dim |
| | self.patch_size = patch_size |
| | self.x_embedder = PatchEmbed( |
| | img_size=img_size, |
| | patch_size=patch_size, |
| | in_chans=in_chans, |
| | embed_dim=self.hidden_size, |
| | strict_img_size=pos_embed_max_size is None, |
| | device=device, |
| | dtype=dtype, |
| | operations=operations, |
| | ) |
| |
|
| | self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations) |
| |
|
| | self.double_y_emb = double_y_emb |
| | if self.double_y_emb: |
| | self.orig_y_embedder = VectorEmbedder( |
| | adm_in_channels, self.hidden_size, dtype, device, operations=operations |
| | ) |
| | self.y_embedder = VectorEmbedder( |
| | self.hidden_size, self.hidden_size, dtype, device, operations=operations |
| | ) |
| | else: |
| | self.y_embedder = VectorEmbedder( |
| | adm_in_channels, self.hidden_size, dtype, device, operations=operations |
| | ) |
| |
|
| | self.transformer_blocks = nn.ModuleList( |
| | DismantledBlock( |
| | hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True, |
| | dtype=dtype, device=device, operations=operations |
| | ) |
| | for _ in range(num_layers) |
| | ) |
| |
|
| | |
| | |
| | self.use_y_embedder = True |
| |
|
| | self.controlnet_blocks = nn.ModuleList([]) |
| | for _ in range(len(self.transformer_blocks)): |
| | controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) |
| | self.controlnet_blocks.append(controlnet_block) |
| |
|
| | self.pos_embed_input = PatchEmbed( |
| | img_size=img_size, |
| | patch_size=patch_size, |
| | in_chans=in_chans, |
| | embed_dim=self.hidden_size, |
| | strict_img_size=False, |
| | device=device, |
| | dtype=dtype, |
| | operations=operations, |
| | ) |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | timesteps: torch.Tensor, |
| | y: Optional[torch.Tensor] = None, |
| | context: Optional[torch.Tensor] = None, |
| | hint = None, |
| | ) -> Tuple[Tensor, List[Tensor]]: |
| | x_shape = list(x.shape) |
| | x = self.x_embedder(x) |
| | if not self.double_y_emb: |
| | h = (x_shape[-2] + 1) // self.patch_size |
| | w = (x_shape[-1] + 1) // self.patch_size |
| | x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device) |
| | c = self.t_embedder(timesteps, dtype=x.dtype) |
| | if y is not None and self.y_embedder is not None: |
| | if self.double_y_emb: |
| | y = self.orig_y_embedder(y) |
| | y = self.y_embedder(y) |
| | c = c + y |
| |
|
| | x = x + self.pos_embed_input(hint) |
| |
|
| | block_out = () |
| |
|
| | repeat = math.ceil(self.main_model_double / len(self.transformer_blocks)) |
| | for i in range(len(self.transformer_blocks)): |
| | out = self.transformer_blocks[i](x, c) |
| | if not self.double_y_emb: |
| | x = out |
| | block_out += (self.controlnet_blocks[i](out),) * repeat |
| |
|
| | return {"output": block_out} |
| |
|