import torch
import torch.nn as nn

from .blocks import FinalLayer
from .blocks import MMDoubleStreamBlock as DiTBlock2
from .blocks import MMSingleStreamBlock as DiTBlock
from .blocks import MMfourStreamBlock as DiTBlock4

from .posemb_layers import get_1d_rotary_pos_embed
from .embedders import (
    TimestepEmbedder,
    MotionEmbedder,
    AudioEmbedder,
    ConditionAudioEmbedder,
    SimpleAudioEmbedder,
    LabelEmbedder,
)
from einops import rearrange

audio_embedder_map = {
    "normal": AudioEmbedder,
    "cond": ConditionAudioEmbedder,
    "simple": SimpleAudioEmbedder,
}


class TalkingHeadDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone for talking-head motion generation.
    """
    def __init__(
        self,
        input_dim=265,
        output_dim=265,
        seq_len=80,
        audio_unit_len=5,
        audio_blocks=12,
        audio_dim=768,
        audio_tokens=1,
        hidden_size=1152,
        depth=28,  # NOTE: currently unused; the block counts are fixed below
        num_heads=16,
        mlp_ratio=4.0,
        audio_embedder_type="normal",
        audio_cond_dim=63,
        norm_type="rms_norm",
        qk_norm="rms_norm",
        **kwargs
    ):
        super().__init__()

        self.num_emo_class = 8
        self.emo_drop_prob = 0.1

        self.num_heads = num_heads
        self.out_channels = output_dim

        self.motion_embedder = MotionEmbedder(input_dim, hidden_size)
        self.identity_embedder = MotionEmbedder(audio_cond_dim, hidden_size)
        self.time_embedder = TimestepEmbedder(hidden_size)
        # Select the audio embedder by type instead of hardcoding "normal".
        self.audio_embedder = audio_embedder_map[audio_embedder_type](
            seq_len=audio_unit_len,
            blocks=audio_blocks,
            channels=audio_dim,
            intermediate_dim=hidden_size,
            output_dim=hidden_size,
            context_tokens=audio_tokens,
            input_len=seq_len,
            condition_dim=audio_cond_dim,
            norm_type=norm_type,
        )
        self.dim = hidden_size // num_heads  # per-head dimension, used for RoPE

        self.emo_embedder = LabelEmbedder(
            num_classes=self.num_emo_class,
            hidden_size=hidden_size,
            dropout_prob=self.emo_drop_prob,
        )
        # Binary log loss for cal_sync_loss; the original left this undefined,
        # so BCEWithLogitsLoss is assumed here.
        self.logloss = nn.BCEWithLogitsLoss()

        # Three four-stream blocks: motion / audio / emotion / identity streams.
        self.blocks4 = nn.ModuleList([
            DiTBlock4(
                hidden_size, num_heads,
                mlp_ratio=mlp_ratio,
                norm_type=norm_type,
                qk_norm=qk_norm,
            ) for _ in range(3)
        ])
        # Six double-stream blocks: motion vs. the merged conditioning stream.
        self.blocks2 = nn.ModuleList([
            DiTBlock2(
                hidden_size, num_heads,
                mlp_ratio=mlp_ratio,
                norm_type=norm_type,
                qk_norm=qk_norm,
            ) for _ in range(6)
        ])
        # Twelve single-stream blocks over the fully concatenated sequence.
        self.blocks = nn.ModuleList([
            DiTBlock(
                hidden_size, num_heads,
                mlp_ratio=mlp_ratio,
                norm_type=norm_type,
                qk_norm=qk_norm,
            ) for _ in range(12)
        ])
        self.final_layer = FinalLayer(hidden_size, self.out_channels, norm_type=norm_type)
        self.initialize_weights()
        self.bank = []

    def initialize_weights(self):
        self.motion_embedder.initialize_weights()
        self.identity_embedder.initialize_weights()
        self.audio_embedder.initialize_weights()
        self.emo_embedder.initialize_weights()
        self.time_embedder.initialize_weights()
        for block in self.blocks:
            block.initialize_weights()
        for block in self.blocks2:
            block.initialize_weights()
        for block in self.blocks4:
            block.initialize_weights()

    def cal_sync_loss(self, audio_embedding, mouth_embedding, label):
        """Audio-visual sync loss: cosine similarity scored against a binary label."""
        if isinstance(label, torch.Tensor):
            gt_d = label.float().view(-1, 1).to(audio_embedding.device)
        else:
            gt_d = (torch.ones([audio_embedding.shape[0], 1]) * label).float().to(audio_embedding.device)
        d = nn.functional.cosine_similarity(audio_embedding, mouth_embedding)
        loss = self.logloss(d.unsqueeze(1), gt_d)
        return loss, d
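    # Example (hypothetical tensors): given per-sample audio and mouth
    # embeddings of shape (B, D) and a binary in-sync label,
    #     loss, sim = model.cal_sync_loss(a_emb, m_emb, label=1)
    # treats the cosine similarity as a logit for the assumed BCEWithLogitsLoss.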

    def forward(self, motion, times, audio, emo, audio_cond, mask=None):
        """
        Forward pass of Talking Head DiT.
        motion: (B, N, xD) tensor of motion feature inputs (head motion, emotion, etc.)
        times: (B,) tensor of diffusion timesteps
        audio: (B, N, M, yD) tensor of audio features (batch_size, video_length, blocks, channels)
        emo: (B,) tensor of emotion class labels
        audio_cond: (B, N, zD) tensor of audio conditional features; averaged over time to a per-clip identity vector
        mask: optional attention mask
        """
        motion_embeds = self.motion_embedder(motion)
        _, seq_len, _ = motion.shape
        time_embeds = self.time_embedder(times)

        emo_embeds = self.emo_embedder(emo, self.training)
        audio_cond = audio_cond.mean(1)  # average over time -> per-clip identity vector
        audio_cond_embeds = self.identity_embedder(audio_cond)

        # 1D rotary position embeddings over the motion sequence.
        freqs_cos, freqs_sin = get_1d_rotary_pos_embed(self.dim, seq_len, theta=256, use_real=True, theta_rescale_factor=1)
        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

        audio_embeds = self.audio_embedder(audio)
        M = audio_embeds.shape[2]  # context tokens per frame
        audio_embeds = rearrange(audio_embeds, "b n m d -> b (n m) d")

        # Global conditioning: timestep plus emotion embedding.
        c = time_embeds + emo_embeds

        # RoPE frequencies for the flattened audio stream (M tokens per frame).
        freqs_cos2 = rearrange(freqs_cos.unsqueeze(0).repeat(M, 1, 1), "n m d -> (n m) d")
        freqs_sin2 = rearrange(freqs_sin.unsqueeze(0).repeat(M, 1, 1), "n m d -> (n m) d")
        freqs_cis2 = (freqs_cos2, freqs_sin2) if freqs_cos2 is not None else None

        # RoPE frequencies for the concatenated audio/emotion/identity streams.
        freqs_cos3 = rearrange(freqs_cos.unsqueeze(0).repeat(3 * M, 1, 1), "n m d -> (n m) d")
        freqs_sin3 = rearrange(freqs_sin.unsqueeze(0).repeat(3 * M, 1, 1), "n m d -> (n m) d")
        freqs_cis3 = (freqs_cos3, freqs_sin3) if freqs_cos3 is not None else None

        # Broadcast the per-clip emotion and identity embeddings along time.
        emo_embeds = emo_embeds.unsqueeze(1).repeat(1, seq_len, 1)
        audio_cond_embeds = audio_cond_embeds.unsqueeze(1).repeat(1, seq_len, 1)

        # Stage 1: four-stream blocks over motion / audio / emotion / identity.
        for block in self.blocks4:
            motion_embeds, audio_embeds, emo_embeds, audio_cond_embeds = block(
                motion_embeds, c, audio_embeds, emo_embeds, audio_cond_embeds,
                mask, freqs_cis, freqs_cis2, causal=False)

        # Stage 2: double-stream blocks over motion vs. the merged condition stream.
        audio_embeds = torch.cat((audio_embeds, emo_embeds, audio_cond_embeds), 1)
        for block in self.blocks2:
            motion_embeds, audio_embeds = block(
                seq_len, motion_embeds, c, audio_embeds,
                mask, freqs_cis, freqs_cis3, causal=False)

        # Stage 3: single-stream blocks over the fully concatenated sequence.
        motion_embeds = torch.cat((motion_embeds, audio_embeds), 1)
        for block in self.blocks:
            motion_embeds = block(seq_len, motion_embeds, c, mask, freqs_cis, freqs_cis3, causal=False)

        # Keep only the motion tokens before the output projection.
        motion_embeds = motion_embeds[:, :seq_len, :]
        out = self.final_layer(motion_embeds, c)
        return out

    def forward_with_cfg(self, motion, times, audio, cfg_scale, emo=None, audio_cond=None):
        """
        Forward pass of DiT that also batches the unconditional forward pass
        for classifier-free guidance.
        """
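        # A minimal CFG sketch, not a verified implementation: it assumes
        # (a) guidance is applied over the emotion label, and (b) the
        # LabelEmbedder reserves index `num_emo_class` as the null/dropped
        # class, as in the standard DiT LabelEmbedder.
        combined_motion = torch.cat([motion, motion], dim=0)
        combined_times = torch.cat([times, times], dim=0)
        combined_audio = torch.cat([audio, audio], dim=0)
        null_emo = torch.full_like(emo, self.num_emo_class)  # assumed null-class index
        combined_emo = torch.cat([emo, null_emo], dim=0)
        combined_cond = torch.cat([audio_cond, audio_cond], dim=0)
        out = self.forward(combined_motion, combined_times, combined_audio,
                           combined_emo, combined_cond)
        cond_out, uncond_out = out.chunk(2, dim=0)
        # Standard classifier-free guidance combination.
        return uncond_out + cfg_scale * (cond_out - uncond_out)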


def TalkingHeadDiT_XL(**kwargs):
    return TalkingHeadDiT(depth=28, hidden_size=1152, num_heads=16, **kwargs)


def TalkingHeadDiT_L(**kwargs):
    return TalkingHeadDiT(depth=24, hidden_size=1024, num_heads=16, **kwargs)


def TalkingHeadDiT_B(**kwargs):
    return TalkingHeadDiT(depth=12, hidden_size=768, num_heads=12, **kwargs)


def TalkingHeadDiT_MM(**kwargs):
    return TalkingHeadDiT(depth=6, hidden_size=768, num_heads=12, **kwargs)


def TalkingHeadDiT_S(**kwargs):
    return TalkingHeadDiT(depth=12, hidden_size=384, num_heads=6, **kwargs)


def TalkingHeadDiT_T(**kwargs):
    return TalkingHeadDiT(depth=6, hidden_size=256, num_heads=4, **kwargs)


TalkingHeadDiT_models = {
    'TalkingHeadDiT-XL': TalkingHeadDiT_XL,
    'TalkingHeadDiT-L': TalkingHeadDiT_L,
    'TalkingHeadDiT-MM': TalkingHeadDiT_MM,
    'TalkingHeadDiT-B': TalkingHeadDiT_B,
    'TalkingHeadDiT-S': TalkingHeadDiT_S,
    'TalkingHeadDiT-T': TalkingHeadDiT_T,
}
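

# A minimal smoke test sketch (assumption: run via `python -m <package>.<module>`
# so the relative imports resolve). Tensor shapes follow the forward() docstring
# and the constructor defaults (input_dim=265, audio_blocks=12, audio_dim=768,
# audio_cond_dim=63, eight emotion classes).
if __name__ == "__main__":
    model = TalkingHeadDiT_S(input_dim=265, output_dim=265)
    B, N = 2, 80
    motion = torch.randn(B, N, 265)
    times = torch.randint(0, 1000, (B,))
    audio = torch.randn(B, N, 12, 768)   # (batch, frames, blocks, channels)
    emo = torch.randint(0, 8, (B,))
    audio_cond = torch.randn(B, N, 63)
    out = model(motion, times, audio, emo, audio_cond)
    print(out.shape)  # expected: (B, N, 265)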