# encoding = 'utf-8'
import math
import os.path as osp
import sys

import torch
import torch.nn as nn
from omegaconf import OmegaConf
from rich.progress import track

from .talking_head_dit import TalkingHeadDiT_models
from ..schedulers.scheduling_ddim import DDIMScheduler
from ..schedulers.flow_matching import ModelSamplingDiscreteFlow

sys.path.append(osp.dirname(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__))))))

scheduler_map = {
    "ddim": DDIMScheduler,
    # "ddpm": DiffusionSchedule,
    "flow_matching": ModelSamplingDiscreteFlow,
}

# Lip-related dimensions of the motion representation.
lip_dims = [18, 19, 20, 36, 37, 38, 42, 43, 44, 51, 52, 53, 57, 58, 59, 60, 61, 62]
class MotionDiffusion(nn.Module):
    def __init__(self, config, device="cuda", dtype=torch.float32, smo_wsize=3, loss_type="l2"):
        super().__init__()
        self.config = config
        self.smo_wsize = smo_wsize
        print("================================== Init Motion GeneratorV2 ==================================")
        print(OmegaConf.to_yaml(self.config))

        motion_gen_config = config.motion_generator
        motion_gen_params = motion_gen_config.params
        audio_proj_config = config.audio_projector
        audio_proj_params = audio_proj_config.params
        scheduler_config = config.noise_scheduler
        scheduler_params = scheduler_config.params
        self.device = device

        # init motion generator
        self.talking_head_dit = TalkingHeadDiT_models[config.model_name](
            input_dim=motion_gen_params.input_dim * 2,
            output_dim=motion_gen_params.output_dim,
            seq_len=motion_gen_params.n_pred_frames,
            audio_unit_len=audio_proj_params.sequence_length,
            audio_blocks=audio_proj_params.blocks,
            audio_dim=audio_proj_params.audio_feat_dim,
            audio_tokens=audio_proj_params.context_tokens,
            audio_embedder_type=audio_proj_params.audio_embedder_type,
            audio_cond_dim=audio_proj_params.audio_cond_dim,
            norm_type=motion_gen_params.norm_type,
            qk_norm=motion_gen_params.qk_norm,
            exp_dim=motion_gen_params.exp_dim,
        )
        self.input_dim = motion_gen_params.input_dim
        self.exp_dim = motion_gen_params.exp_dim
        self.audio_feat_dim = audio_proj_params.audio_feat_dim
        self.audio_seq_len = audio_proj_params.sequence_length
        self.audio_blocks = audio_proj_params.blocks
        self.audio_margin = (audio_proj_params.sequence_length - 1) // 2
        # Window offsets [-audio_margin, ..., audio_margin], shape
        # 1 x (2 * audio_margin + 1), e.g. [-2, -1, 0, 1, 2] for a margin of 2.
        self.indices = (
            torch.arange(2 * self.audio_margin + 1) - self.audio_margin
        ).unsqueeze(0)
        self.n_prev_frames = motion_gen_params.n_prev_frames
        self.n_pred_frames = motion_gen_params.n_pred_frames
        # init diffusion scheduler
        self.scheduler = scheduler_map[scheduler_config.type](
            num_train_timesteps=scheduler_params.num_train_timesteps,
            beta_start=scheduler_params.beta_start,
            beta_end=scheduler_params.beta_end,
            beta_schedule=scheduler_params.mode,
            prediction_type=scheduler_config.sample_mode,
            time_shifting=scheduler_params.time_shifting,
        )
        self.scheduler_type = scheduler_config.type
        self.eta = scheduler_params.eta
        self.scheduler.set_timesteps(scheduler_params.num_inference_steps, device=self.device)
        self.timesteps = self.scheduler.timesteps
        print(f"time steps: {self.timesteps}")
        self.sample_mode = scheduler_config.sample_mode
        assert self.sample_mode in ["noise", "sample"], \
            f"Unknown sample mode {self.sample_mode}, should be noise or sample"

        # init other params
        self.audio_drop_ratio = config.train.audio_drop_ratio
        self.pre_drop_ratio = config.train.pre_drop_ratio
        # Null (unconditional) embeddings for classifier-free guidance. Create
        # them directly on the target device/dtype: calling .to() on an
        # nn.Parameter returns a plain tensor and breaks parameter registration.
        self.null_audio_feat = nn.Parameter(
            torch.randn(1, 1, 1, 1, self.audio_feat_dim, device=self.device, dtype=dtype),
            requires_grad=True,
        )
        self.null_motion_feat = nn.Parameter(
            torch.randn(1, 1, self.input_dim, device=self.device, dtype=dtype),
            requires_grad=True,
        )

        # for segment fusion: linear cross-fade weights over the overlap region
        self.overlap_len = min(16, self.n_pred_frames - 16)
        self.fuse_alpha = torch.arange(
            self.overlap_len, device=self.device, dtype=dtype
        ).reshape(1, -1, 1) / self.overlap_len

        self.dtype = dtype
        self.loss_type = loss_type
        total_params = sum(p.numel() for p in self.parameters())
        print('Number of parameters: %.4fM' % (total_params / 1e6))
        print("================================== Init Motion GeneratorV2: Done ==================================")

    def _smooth(self, motion):
        # motion: B x L x D
        if self.smo_wsize <= 1:
            return motion
        new_motion = motion.clone()
        n = motion.shape[1]
        half_k = self.smo_wsize // 2
        for i in range(n):
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            # only smooth head-pose motion; expression dims (< exp_dim) are untouched
            motion[:, i, self.exp_dim:] = torch.mean(new_motion[:, ss:ee, self.exp_dim:], dim=1)
        return motion

    def _fuse(self, prev_motion, cur_motion):
        # Cross-fade the trailing overlap_len frames of the previous segment
        # with the leading overlap_len frames of the current one; fuse_alpha
        # ramps linearly from 0 (all previous) toward 1 (all current).
        r1 = prev_motion[:, -self.overlap_len:]
        r2 = cur_motion[:, :self.overlap_len]
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha
        prev_motion[:, -self.overlap_len:] = r_fuse  # fuse in place
        return prev_motion

    def sample_subclip(
        self,
        audio,
        ref_kp,
        prev_motion,
        emo=None,
        cfg_scale=1.15,
        init_latents=None,
        dynamic_threshold=None,
    ):
        # prepare audio features
        batch_size = audio.shape[0]
        audio = audio.to(self.device)
        if audio.ndim == 4:
            audio = audio.unsqueeze(2)
        # reference keypoints
        ref_kp = ref_kp.view(batch_size, 1, -1)
        # classifier-free guidance: prepend the unconditional branch
        if cfg_scale > 1:
            uncond_audio = self.null_audio_feat.expand(
                batch_size, self.n_pred_frames, self.audio_seq_len, self.audio_blocks, -1
            )
            audio = torch.cat([uncond_audio, audio], dim=0)
            ref_kp = torch.cat([ref_kp] * 2, dim=0)
            if emo is not None:
                uncond_emo = torch.Tensor([self.talking_head_dit.num_emo_class]).long().to(self.device)
                emo = torch.cat([uncond_emo, emo], dim=0)
        ref_kp = ref_kp.repeat(1, audio.shape[1], 1)  # B, L, kD
        # prepare noisy motion latents
        if init_latents is None:
            latents = torch.randn((batch_size, self.n_pred_frames, self.input_dim)).to(self.device)
        else:
            latents = init_latents
        prev_motion = prev_motion.expand_as(latents).to(dtype=self.dtype)
        latents = latents.to(dtype=self.dtype)
        audio = audio.to(dtype=self.dtype)
        ref_kp = ref_kp.to(dtype=self.dtype)
        for t in track(self.timesteps, description='🚀Denoising', total=len(self.timesteps)):
            motion_in = torch.cat([prev_motion, latents], dim=-1)
            step_in = torch.tensor([t] * batch_size, device=self.device, dtype=self.dtype)
            if cfg_scale > 1:
                motion_in = torch.cat([motion_in] * 2, dim=0)
                step_in = torch.cat([step_in] * 2, dim=0)
            # predict
            pred = self.talking_head_dit(
                motion=motion_in,
                times=step_in,
                audio=audio,
                emo=emo,
                audio_cond=ref_kp,
            )
            # dynamic thresholding: clamp each sample to its own quantile
            if dynamic_threshold:
                dt_ratio, dt_min, dt_max = dynamic_threshold
                abs_results = pred.reshape(pred.shape[0], -1).abs()
                s = torch.quantile(abs_results, dt_ratio, dim=1)
                s = torch.clamp(s, min=dt_min, max=dt_max)
                s = s[..., None, None]
                pred = torch.clamp(pred, min=-s, max=s)
            # CFG: blend unconditional and conditional predictions
            if cfg_scale > 1:
                # uncond_pred, emo_cond_pred, all_cond_pred = pred.chunk(3, dim=0)
                # pred = uncond_pred + 8 * (emo_cond_pred - uncond_pred) + 1.2 * (all_cond_pred - emo_cond_pred)
                uncond_pred, cond_pred = pred.chunk(2, dim=0)
                pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)
            # scheduler step
            latents = self.scheduler.step(pred, t, latents, eta=self.eta, return_dict=False)[0]
        self.talking_head_dit.bank = []  # clear the model's feature bank between subclips
        return latents

    def sample(self, audio, ref_kp, prev_motion, cfg_scale=1.15, audio_pad_mode="zero", emo=None, dynamic_threshold=None):
        # prev_motion: B, 1, D
        # Inference on audio of arbitrary length: crop the audio into
        # n_subdivision overlapping windows of n_pred_frames frames each.
        clip_len = audio.shape[0]
        stride = self.n_pred_frames - self.overlap_len
        if clip_len <= self.n_pred_frames:
            n_subdivision = 1
        else:
            n_subdivision = math.ceil((clip_len - self.n_pred_frames) / stride) + 1
        # padding
        n_padding_frames = self.n_pred_frames + stride * (n_subdivision - 1) - clip_len
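        # Worked example with hypothetical sizes: for n_pred_frames = 80 and
        # overlap_len = 16, stride = 64, so a 200-frame clip yields
        # n_subdivision = ceil((200 - 80) / 64) + 1 = 3 windows and
        # n_padding_frames = 80 + 64 * 2 - 200 = 8 trailing pad frames.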
        padding_value = 0  # placeholder; only used when n_padding_frames > 0
        if n_padding_frames > 0:
            if audio_pad_mode == 'zero':
                padding_value = torch.zeros_like(audio[-1:])
            elif audio_pad_mode == 'replicate':
                padding_value = audio[-1:]
            else:
                raise ValueError(f'Unknown pad mode: {audio_pad_mode}')
        # pad both ends by audio_margin so every frame gets a full context window
        audio = torch.cat(
            [audio[:1]] * self.audio_margin
            + [audio] + [padding_value] * n_padding_frames
            + [audio[-1:]] * self.audio_margin,
            dim=0
        )
        center_indices = torch.arange(
            self.audio_margin,
            audio.shape[0] - self.audio_margin
        ).unsqueeze(1) + self.indices
        audio_tensor = audio[center_indices]  # T, L, b, aD

        motion_lst = []
        # init_latents = torch.randn((1, self.n_pred_frames, self.motion_dim)).to(device=self.device)
        init_latents = None
        # emotion label
        if emo is not None:
            emo = torch.Tensor([emo]).long().to(self.device)
        start_idx = 0
        for i in range(n_subdivision):
            print(f"Sample subclip {i + 1}/{n_subdivision}")
            end_idx = start_idx + self.n_pred_frames
            audio_segment = audio_tensor[start_idx:end_idx].unsqueeze(0)
            start_idx += stride
            motion_segment = self.sample_subclip(
                audio=audio_segment,
                ref_kp=ref_kp,
                prev_motion=prev_motion,
                emo=emo,
                cfg_scale=cfg_scale,
                init_latents=init_latents,
                dynamic_threshold=dynamic_threshold,
            )
            # smooth head pose within the segment
            motion_segment = self._smooth(motion_segment)
            # update prev_motion with the last frame before the next window starts
            prev_motion = motion_segment[:, stride - 1:stride].clone()
            # save results
            motion_coef = motion_segment
            if i == n_subdivision - 1 and n_padding_frames > 0:
                motion_coef = motion_coef[:, :-n_padding_frames]  # drop padded frames
            if len(motion_lst) > 0:
                # cross-fade the overlap with the previous segment
                motion_lst[-1] = self._fuse(motion_lst[-1], motion_coef)
                motion_lst.append(motion_coef[:, self.overlap_len:])
            else:
                motion_lst.append(motion_coef)
        motion = torch.cat(motion_lst, dim=1)
        # smooth the full clip
        motion = self._smooth(motion)
        motion = motion.squeeze()
        return motion.float()
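

# ---------------------------------------------------------------------------
# Minimal usage sketch. The config path, feature shapes, and keypoint size are
# assumptions for illustration (a real pipeline would load the repo's actual
# config and audio encoder); only the MotionDiffusion calls mirror the class.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = OmegaConf.load("configs/motion_generator.yaml")  # hypothetical path
    model = MotionDiffusion(cfg, device="cuda", dtype=torch.float32)

    n_frames = 200  # driving-audio length in video frames (example value)
    # Placeholder per-frame audio features, shape (T, blocks, audio_feat_dim).
    audio_feats = torch.randn(n_frames, model.audio_blocks, model.audio_feat_dim)
    # Placeholder reference keypoints (flattened internally; 63 is a guess)
    # and a zero initial motion; both must live on the model's device.
    ref_kp = torch.randn(1, 63, device=model.device)
    prev_motion = torch.zeros(1, 1, model.input_dim, device=model.device)

    with torch.no_grad():
        motion = model.sample(audio_feats, ref_kp, prev_motion, cfg_scale=1.15)
    print(motion.shape)  # expected: (n_frames, input_dim)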