# Newest version: add local & global context (cross-attn), and local & global attn (self-attn)
import math
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .moe_layers import MoEBlock


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)


class Timesteps(nn.Module):
    def __init__(
        self,
        num_channels: int,
        downscale_freq_shift: float = 0.0,
        scale: int = 1,
        max_period: int = 10000,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale
        self.max_period = max_period

    def forward(self, timesteps):
        assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
        embedding_dim = self.num_channels
        half_dim = embedding_dim // 2
        exponent = -math.log(self.max_period) * torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
        )
        exponent = exponent / (half_dim - self.downscale_freq_shift)

        emb = torch.exp(exponent)
        emb = timesteps[:, None].float() * emb[None, :]
        emb = self.scale * emb
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        if embedding_dim % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
        return emb


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(
        self,
        hidden_size,
        frequency_embedding_size=256,
        cond_proj_dim=None,
        out_size=None,
    ):
        super().__init__()
        if out_size is None:
            out_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, frequency_embedding_size, bias=True),
            nn.GELU(),
            nn.Linear(frequency_embedding_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(
                cond_proj_dim, frequency_embedding_size, bias=False
            )

        self.time_embed = Timesteps(hidden_size)

    def forward(self, t, condition):
        t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype)
        # t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
        if condition is not None:
            t_freq = t_freq + self.cond_proj(condition)
        t = self.mlp(t_freq)
        t = t.unsqueeze(dim=1)
        return t


class MLP(nn.Module):
    def __init__(self, *, width: int):
        super().__init__()
        self.width = width
        self.fc1 = nn.Linear(width, width * 4)
        self.fc2 = nn.Linear(width * 4, width)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))

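
# Usage note (illustrative only, not part of the model): the two embeddings above
# differ mainly in backend. `get_1d_sincos_pos_embed_from_grid` is NumPy-based and
# is used once at init time to fill the fixed positional buffer, while `Timesteps`
# runs on-device at every forward pass. A minimal sketch with hypothetical sizes:
#
#   pos = np.arange(16, dtype=np.float32)              # (M,) = (16,)
#   pe = get_1d_sincos_pos_embed_from_grid(64, pos)     # (M, D) = (16, 64)
#
#   t = torch.tensor([10.0, 500.0])                     # (B,) = (2,)
#   emb = Timesteps(num_channels=64)(t)                 # (B, D) = (2, 64)
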
class CrossAttention(nn.Module):
    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        with_decoupled_ca=False,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        **kwargs,
    ):
        super().__init__()
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert (
            self.head_dim % 8 == 0 and self.head_dim <= 128
        ), "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim**-0.5

        self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)
        self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)
        self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.out_proj = nn.Linear(qdim, qdim, bias=True)

        self.with_dca = with_decoupled_ca
        if self.with_dca:
            self.kv_proj_dca = nn.Linear(kdim, 2 * qdim, bias=qkv_bias)
            self.k_norm_dca = (
                norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
                if qk_norm
                else nn.Identity()
            )
            self.dca_dim = decoupled_ca_dim
            self.dca_weight = decoupled_ca_weight

        # zero init
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        """
        b, s1, c = x.shape  # [b, s1, D]

        if self.with_dca:
            token_len = y.shape[1]
            context_dca = y[:, -self.dca_dim :, :]
            kv_dca = self.kv_proj_dca(context_dca).view(
                b, self.dca_dim, 2, self.num_heads, self.head_dim
            )
            k_dca, v_dca = kv_dca.unbind(dim=2)  # [b, s, h, d]
            k_dca = self.k_norm_dca(k_dca)
            y = y[:, : (token_len - self.dca_dim), :]

        _, s2, c = y.shape  # [b, s2, 1024]

        q = self.to_q(x)
        k = self.to_k(y)
        v = self.to_v(y)

        kv = torch.cat((k, v), dim=-1)
        split_size = kv.shape[-1] // self.num_heads // 2
        kv = kv.view(1, -1, self.num_heads, split_size * 2)
        k, v = torch.split(kv, split_size, dim=-1)

        q = q.view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]
        k = k.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]
        v = v.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]

        q = self.q_norm(q)
        k = self.k_norm(k)

        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=True
        ):
            q, k, v = map(
                lambda t: rearrange(t, "b n h d -> b h n d", h=self.num_heads),
                (q, k, v),
            )
            context = (
                F.scaled_dot_product_attention(q, k, v)
                .transpose(1, 2)
                .reshape(b, s1, -1)
            )

        if self.with_dca:
            with torch.backends.cuda.sdp_kernel(
                enable_flash=True, enable_math=False, enable_mem_efficient=True
            ):
                k_dca, v_dca = map(
                    lambda t: rearrange(t, "b n h d -> b h n d", h=self.num_heads),
                    (k_dca, v_dca),
                )
                context_dca = (
                    F.scaled_dot_product_attention(q, k_dca, v_dca)
                    .transpose(1, 2)
                    .reshape(b, s1, -1)
                )
            context = context + self.dca_weight * context_dca

        out = self.out_proj(context)  # context.reshape - B, L1, -1

        return out

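
# Decoupled cross-attention (DCA) input layout, for reference (hypothetical names):
# when `with_decoupled_ca=True`, the key/value sequence `y` is expected to carry the
# decoupled tokens at its tail, i.e. y = [regular context | dca context]:
#
#   y = torch.cat([main_tokens, extra_tokens], dim=1)   # [B, s2 + dca_dim, kdim]
#
# The last `decoupled_ca_dim` tokens go through `kv_proj_dca`, and their attention
# output is blended back in as `context + dca_weight * context_dca`; the remaining
# tokens follow the ordinary cross-attention path. `main_tokens` / `extra_tokens`
# above are illustrative placeholders, not attributes of this module.
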
class Attention(nn.Module):
    """
    We rename some layer names to align with flash attention
    """

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        use_global_processor=False,
    ):
        super().__init__()
        self.use_global_processor = use_global_processor
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, "dim should be divisible by num_heads"
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert (
            self.head_dim % 8 == 0 and self.head_dim <= 128
        ), "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim**-0.5

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, dim, bias=qkv_bias)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.out_proj = nn.Linear(dim, dim)

        # set processor
        self.processor = LocalGlobalProcessor(use_global=use_global_processor)

    def forward(self, x):
        return self.processor(self, x)


class AttentionPool(nn.Module):
    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x, attention_mask=None):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2)
            global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0)
            x = torch.cat([global_emb[None,], x], dim=0)
        else:
            x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC

        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1],
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return x.squeeze(0)


class LocalGlobalProcessor:
    def __init__(self, use_global=False):
        self.use_global = use_global

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
    ):
        """
        hidden_states: [B, L, C]
        """
        if self.use_global:
            B_old, N_old, C_old = hidden_states.shape
            hidden_states = hidden_states.reshape(1, -1, C_old)

        B, N, C = hidden_states.shape

        q = attn.to_q(hidden_states)
        k = attn.to_k(hidden_states)
        v = attn.to_v(hidden_states)

        qkv = torch.cat((q, k, v), dim=-1)
        split_size = qkv.shape[-1] // attn.num_heads // 3
        qkv = qkv.view(1, -1, attn.num_heads, split_size * 3)
        q, k, v = torch.split(qkv, split_size, dim=-1)

        q = q.reshape(B, N, attn.num_heads, attn.head_dim).transpose(1, 2)  # [b, h, s, d]
        k = k.reshape(B, N, attn.num_heads, attn.head_dim).transpose(1, 2)  # [b, h, s, d]
        v = v.reshape(B, N, attn.num_heads, attn.head_dim).transpose(1, 2)

        q = attn.q_norm(q)  # [b, h, s, d]
        k = attn.k_norm(k)  # [b, h, s, d]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=True
        ):
            hidden_states = F.scaled_dot_product_attention(q, k, v)
            hidden_states = hidden_states.transpose(1, 2).reshape(B, N, -1)

        hidden_states = attn.out_proj(hidden_states)

        if self.use_global:
            hidden_states = hidden_states.reshape(B_old, N_old, -1)

        return hidden_states

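
# Local vs. global self-attention, for reference: with `use_global=False` each sample
# attends only to its own tokens; with `use_global=True` the processor flattens the
# batch to [1, B*N, C] so tokens from all parts in the batch attend to one another,
# then restores the original batch shape. A minimal sketch (hypothetical sizes):
#
#   x = torch.randn(4, 64, 1024)             # 4 parts, 64 tokens each
#   attn = Attention(dim=1024, num_heads=16)
#   local_out = attn(x)                        # attention within each part: [4, 64, 1024]
#   attn.processor = LocalGlobalProcessor(use_global=True)
#   global_out = attn(x)                       # attention over all 4*64 tokens, same shape
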
class PartFormerDitBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        use_self_attention: bool = True,
        use_cross_attention: bool = False,
        use_cross_attention_2: bool = False,
        encoder_hidden_dim=1024,  # cross-attn encoder_hidden_states dim
        encoder_hidden2_dim=1024,  # cross-attn 2 encoder_hidden_states dim
        # cross_attn2_weight=0.0,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        qk_norm_layer=nn.RMSNorm,
        with_decoupled_ca=False,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        skip_connection=False,
        timested_modulate=False,
        c_emb_size=0,  # time embedding size
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
    ):
        super().__init__()
        # self.cross_attn2_weight = cross_attn2_weight
        use_ele_affine = True

        # ========================= Self-Attention =========================
        self.use_self_attention = use_self_attention
        if self.use_self_attention:
            self.norm1 = norm_layer(
                hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
            )
            self.attn1 = Attention(
                hidden_size,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                norm_layer=qk_norm_layer,
            )

        # ========================= Add =========================
        # Simply use add like SDXL.
        self.timested_modulate = timested_modulate
        if self.timested_modulate:
            self.default_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(c_emb_size, hidden_size, bias=True)
            )

        # ========================= Cross-Attention =========================
        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.norm2 = norm_layer(
                hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
            )
            self.attn2 = CrossAttention(
                hidden_size,
                encoder_hidden_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                norm_layer=qk_norm_layer,
                with_decoupled_ca=False,
            )

        self.use_cross_attention_2 = use_cross_attention_2
        if self.use_cross_attention_2:
            self.norm2_2 = norm_layer(
                hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
            )
            self.attn2_2 = CrossAttention(
                hidden_size,
                encoder_hidden2_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                norm_layer=qk_norm_layer,
                with_decoupled_ca=with_decoupled_ca,
                decoupled_ca_dim=decoupled_ca_dim,
                decoupled_ca_weight=decoupled_ca_weight,
            )

        # ========================= FFN =========================
        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
        self.use_moe = use_moe
        if self.use_moe:
            print("using moe")
            self.moe = MoEBlock(
                hidden_size,
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                dropout=0.0,
                activation_fn="gelu",
                final_dropout=False,
                ff_inner_dim=int(hidden_size * 4.0),
                ff_bias=True,
            )
        else:
            self.mlp = MLP(width=hidden_size)

        # ========================= Skip Connection =========================
        if skip_connection:
            self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
            self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
        else:
            self.skip_linear = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_hidden_states_2: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        skip_value: torch.Tensor = None,
    ):
        # skip connection
        if self.skip_linear is not None:
            cat = torch.cat([skip_value, hidden_states], dim=-1)
            hidden_states = self.skip_linear(cat)
            hidden_states = self.skip_norm(hidden_states)

        # local-global attn (self-attn)
        if self.timested_modulate:
            shift_msa = self.default_modulation(temb).unsqueeze(dim=1)
            hidden_states = hidden_states + shift_msa
        if self.use_self_attention:
            attn_output = self.attn1(self.norm1(hidden_states))
            hidden_states = hidden_states + attn_output

        # 1. image cross-attn
        if self.use_cross_attention:
            original_cross_out = self.attn2(
                self.norm2(hidden_states),
                encoder_hidden_states,
            )

        # 2. added local-global cross-attn
        if self.use_cross_attention_2:
            cross_out_2 = self.attn2_2(
                self.norm2_2(hidden_states),
                encoder_hidden_states_2,
            )

        hidden_states = (
            hidden_states
            + (original_cross_out if self.use_cross_attention else 0)
            + (cross_out_2 if self.use_cross_attention_2 else 0)
        )

        # FFN Layer
        mlp_inputs = self.norm3(hidden_states)
        if self.use_moe:
            hidden_states = hidden_states + self.moe(mlp_inputs)
        else:
            hidden_states = hidden_states + self.mlp(mlp_inputs)

        return hidden_states

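
# Block-level usage sketch (illustrative only; all dimensions below are hypothetical,
# chosen to satisfy this file's asserts, not values used in training):
#
#   block = PartFormerDitBlock(
#       hidden_size=1024, num_heads=16,
#       use_cross_attention=True, use_cross_attention_2=True,
#       encoder_hidden_dim=1024, encoder_hidden2_dim=1024,
#   )
#   h   = torch.randn(2, 65, 1024)   # [B, 1 + N, C] (time token + latent tokens)
#   obj = torch.randn(2, 77, 1024)   # object/image context for attn2
#   geo = torch.randn(2, 33, 1024)   # geometry context for attn2_2
#   out = block(hidden_states=h, encoder_hidden_states=obj,
#               encoder_hidden_states_2=geo)   # -> [2, 65, 1024]
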
class FinalLayer(nn.Module):
    """
    The final layer of HunYuanDiT.
    """

    def __init__(self, final_hidden_size, out_channels):
        super().__init__()
        self.final_hidden_size = final_hidden_size
        self.norm_final = nn.LayerNorm(
            final_hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)

    def forward(self, x):
        x = self.norm_final(x)
        x = x[:, 1:]  # drop the prepended time token
        x = self.linear(x)
        return x

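
# Note on the token layout: PartFormerDITPlain.forward below prepends the time-embedding
# token `c` to the latent sequence (`x = torch.cat([c, x], dim=1)`), so every block sees
# a sequence of length 1 + N. `FinalLayer` drops that extra token again via `x[:, 1:]`
# before projecting back to `out_channels`, which is why the output length matches the
# input latent length N.
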
class PartFormerDITPlain(nn.Module):
    def __init__(
        self,
        input_size=1024,
        in_channels=4,
        hidden_size=1024,
        use_self_attention=True,
        use_cross_attention=True,
        use_cross_attention_2=True,
        encoder_hidden_dim=1024,  # cross-attn encoder_hidden_states dim
        encoder_hidden2_dim=1024,  # cross-attn 2 encoder_hidden_states dim
        depth=24,
        num_heads=16,
        qk_norm=False,
        qkv_bias=True,
        norm_type="layer",
        qk_norm_type="rms",
        with_decoupled_ca=False,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        use_pos_emb=False,
        # use_attention_pooling=True,
        guidance_cond_proj_dim=None,
        num_moe_layers: int = 6,
        num_experts: int = 8,
        moe_top_k: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.input_size = input_size
        self.depth = depth
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.norm = nn.LayerNorm if norm_type == "layer" else nn.RMSNorm
        self.qk_norm = nn.RMSNorm if qk_norm_type == "rms" else nn.LayerNorm

        # embedding
        self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(
            hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim
        )

        # Will use fixed sin-cos embedding:
        self.use_pos_emb = use_pos_emb
        if self.use_pos_emb:
            self.register_buffer("pos_embed", torch.zeros(1, input_size, hidden_size))
            pos = np.arange(self.input_size, dtype=np.float32)
            pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], pos)
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # self.use_attention_pooling = use_attention_pooling
        # if use_attention_pooling:
        #     self.pooler = AttentionPool(
        #         self.text_len, encoder_hidden_dim, num_heads=8, output_dim=1024
        #     )
        #     self.extra_embedder = nn.Sequential(
        #         nn.Linear(1024, hidden_size * 4),
        #         nn.SiLU(),
        #         nn.Linear(hidden_size * 4, hidden_size, bias=True),
        #     )

        # for part embedding
        self.use_bbox_cond = kwargs.get("use_bbox_cond", False)
        if self.use_bbox_cond:
            # NOTE: BboxEmbedder is not defined in this file; it must be provided by
            # the surrounding package.
            self.bbox_conditioner = BboxEmbedder(
                out_size=hidden_size,
                num_freqs=kwargs.get("num_freqs", 8),
            )
        self.use_part_embed = kwargs.get("use_part_embed", False)
        if self.use_part_embed:
            self.valid_num = kwargs.get("valid_num", 50)
            self.part_embed = nn.Parameter(torch.randn(self.valid_num, hidden_size))
            # zero init part_embed
            self.part_embed.data.zero_()

        # transformer blocks
        self.blocks = nn.ModuleList(
            [
                PartFormerDitBlock(
                    hidden_size,
                    num_heads,
                    use_self_attention=use_self_attention,
                    use_cross_attention=use_cross_attention,
                    use_cross_attention_2=use_cross_attention_2,
                    encoder_hidden_dim=encoder_hidden_dim,  # cross-attn encoder_hidden_states dim
                    encoder_hidden2_dim=encoder_hidden2_dim,  # cross-attn 2 encoder_hidden_states dim
                    # cross_attn2_weight=cross_attn2_weight,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    norm_layer=self.norm,
                    qk_norm_layer=self.qk_norm,
                    with_decoupled_ca=with_decoupled_ca,
                    decoupled_ca_dim=decoupled_ca_dim,
                    decoupled_ca_weight=decoupled_ca_weight,
                    skip_connection=layer > depth // 2,
                    use_moe=True if depth - layer <= num_moe_layers else False,
                    num_experts=num_experts,
                    moe_top_k=moe_top_k,
                )
                for layer in range(depth)
            ]
        )
        # set local-global processor on every second block
        for layer, block in enumerate(self.blocks):
            if hasattr(block, "attn1") and (layer + 1) % 2 == 0:
                block.attn1.processor = LocalGlobalProcessor(use_global=True)

        self.depth = depth

        self.final_layer = FinalLayer(hidden_size, self.out_channels)

    def forward(self, x, t, contexts: dict, **kwargs):
        """
        x: [B, N, C]
        t: [B]
        contexts: dict
            image_context: [B, K*ni, C]
            geo_context: [B, K*ng, C] or [B, K*ng, C*2]
        aabb: [B, K, 2, 3]
        num_tokens: [B, N]
        N = K * num_tokens
        For parts pretrain: K = 1
        """
        # prepare input
        aabb: torch.Tensor = kwargs.get("aabb", None)
        # image_context = contexts.get("image_un_cond", None)
        object_context = contexts.get("obj_cond", None)
        geo_context = contexts.get("geo_cond", None)
        num_tokens: torch.Tensor = kwargs.get("num_tokens", None)

        # time embedding and input projection
        t = self.t_embedder(t, condition=kwargs.get("guidance_cond"))
        x = self.x_embedder(x)
        if self.use_pos_emb:
            pos_embed = self.pos_embed.to(x.dtype)
            x = x + pos_embed

        # c is time embedding (adding pooling context or not)
        # if self.use_attention_pooling:
        #     # TODO: attention_pooling for all contexts
        #     extra_vec = self.pooler(image_context, None)
        #     c = t + self.extra_embedder(extra_vec)  # [B, D]
        # else:
        #     c = t
        c = t

        # bounding box
        if self.use_bbox_cond:
            center_extent = torch.cat(
                [torch.mean(aabb, dim=-2), aabb[..., 1, :] - aabb[..., 0, :]], dim=-1
            )
            bbox_embeds = self.bbox_conditioner(center_extent)
            # TODO: now only support batch_size=1
            bbox_embeds = torch.repeat_interleave(
                bbox_embeds, repeats=num_tokens[0], dim=1
            )
            x = x + bbox_embeds

        # part id embedding
        if self.use_part_embed:
            num_parts = aabb.shape[1]
            random_idx = torch.randperm(self.valid_num)[:num_parts]
            part_embeds = self.part_embed[random_idx].unsqueeze(1)
            x = x + part_embeds

        x = torch.cat([c, x], dim=1)

        skip_value_list = []
        for layer, block in enumerate(self.blocks):
            skip_value = None if layer <= self.depth // 2 else skip_value_list.pop()
            x = block(
                hidden_states=x,
                # encoder_hidden_states=image_context,
                encoder_hidden_states=object_context,
                encoder_hidden_states_2=geo_context,
                temb=c,
                skip_value=skip_value,
            )
            if layer < self.depth // 2:
                skip_value_list.append(x)

        x = self.final_layer(x)
        return x
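

# A minimal smoke-test sketch for the full model. This is illustrative only: the
# shapes, context keys, and hyper-parameters below are assumptions chosen to satisfy
# this file's asserts (head_dim divisible by 8, etc.), not values used by the authors,
# and running it still requires the surrounding package (e.g. `.moe_layers`).
if __name__ == "__main__":
    model = PartFormerDITPlain(
        in_channels=64,
        hidden_size=512,
        depth=4,
        num_heads=8,
        num_moe_layers=0,          # skip the MoE FFN in this quick check
        encoder_hidden_dim=512,
        encoder_hidden2_dim=512,
    )
    x = torch.randn(2, 16, 64)     # [B, N, C_in] latent tokens
    t = torch.randint(0, 1000, (2,)).float()
    contexts = {
        "obj_cond": torch.randn(2, 32, 512),   # cross-attn context
        "geo_cond": torch.randn(2, 24, 512),   # cross-attn 2 context
    }
    out = model(x, t, contexts)
    print(out.shape)               # expected: torch.Size([2, 16, 64])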